xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/SerializerTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2020-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "SerializerTestUtils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/LstmParams.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/QuantizedLstmParams.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <random>
18*89c4ff92SAndroid Build Coastguard Worker #include <vector>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker using armnnDeserializer::IDeserializer;
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("SerializerTests")
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeAddition")
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("addition");
30*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo tensorInfo({1, 2, 3}, armnn::DataType::Float32);
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
33*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
34*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
35*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
36*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const additionLayer = network->AddAdditionLayer(layerName.c_str());
37*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
38*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(0));
41*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(1));
42*89c4ff92SAndroid Build Coastguard Worker     additionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(tensorInfo);
45*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
46*89c4ff92SAndroid Build Coastguard Worker     additionLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     std::string serializedNetwork = SerializeNetwork(*network);
49*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(serializedNetwork);
50*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {tensorInfo, tensorInfo}, {tensorInfo});
53*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker 
SerializeArgMinMaxTest(armnn::DataType dataType)56*89c4ff92SAndroid Build Coastguard Worker void SerializeArgMinMaxTest(armnn::DataType dataType)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("argminmax");
59*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 2, 3}, armnn::DataType::Float32);
60*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 3}, dataType);
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     armnn::ArgMinMaxDescriptor descriptor;
63*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = armnn::ArgMinMaxFunction::Max;
64*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Axis = 1;
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
67*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
68*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const argMinMaxLayer = network->AddArgMinMaxLayer(descriptor, layerName.c_str());
69*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(argMinMaxLayer->GetInputSlot(0));
72*89c4ff92SAndroid Build Coastguard Worker     argMinMaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
75*89c4ff92SAndroid Build Coastguard Worker     argMinMaxLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
78*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ArgMinMaxDescriptor> verifier(layerName,
81*89c4ff92SAndroid Build Coastguard Worker                                                                          {inputInfo},
82*89c4ff92SAndroid Build Coastguard Worker                                                                          {outputInfo},
83*89c4ff92SAndroid Build Coastguard Worker                                                                          descriptor);
84*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeArgMinMaxSigned32")
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker     SerializeArgMinMaxTest(armnn::DataType::Signed32);
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeArgMinMaxSigned64")
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker     SerializeArgMinMaxTest(armnn::DataType::Signed64);
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeBatchMatMul")
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("batchMatMul");
100*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputXInfo({2, 3, 4, 5}, armnn::DataType::Float32);
101*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputYInfo({2, 4, 3, 5}, armnn::DataType::Float32);
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({2, 3, 3, 5}, armnn::DataType::Float32);
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     armnn::BatchMatMulDescriptor descriptor(false,
106*89c4ff92SAndroid Build Coastguard Worker                                             false,
107*89c4ff92SAndroid Build Coastguard Worker                                             false,
108*89c4ff92SAndroid Build Coastguard Worker                                             false,
109*89c4ff92SAndroid Build Coastguard Worker                                             armnn::DataLayout::NHWC,
110*89c4ff92SAndroid Build Coastguard Worker                                             armnn::DataLayout::NHWC);
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
113*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputXLayer = network->AddInputLayer(0);
114*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputYLayer = network->AddInputLayer(1);
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const batchMatMulLayer =
117*89c4ff92SAndroid Build Coastguard Worker         network->AddBatchMatMulLayer(descriptor, layerName.c_str());
118*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker     inputXLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(0));
121*89c4ff92SAndroid Build Coastguard Worker     inputYLayer->GetOutputSlot(0).Connect(batchMatMulLayer->GetInputSlot(1));
122*89c4ff92SAndroid Build Coastguard Worker     batchMatMulLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     inputXLayer->GetOutputSlot(0).SetTensorInfo(inputXInfo);
125*89c4ff92SAndroid Build Coastguard Worker     inputYLayer->GetOutputSlot(0).SetTensorInfo(inputYInfo);
126*89c4ff92SAndroid Build Coastguard Worker     batchMatMulLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
129*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::BatchMatMulDescriptor> verifier(layerName,
132*89c4ff92SAndroid Build Coastguard Worker                                                                            {inputXInfo, inputYInfo},
133*89c4ff92SAndroid Build Coastguard Worker                                                                            {outputInfo},
134*89c4ff92SAndroid Build Coastguard Worker                                                                            descriptor);
135*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeBatchNormalization")
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("batchNormalization");
141*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 3, 3, 1 }, armnn::DataType::Float32);
142*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo meanInfo({1}, armnn::DataType::Float32, 0.0f, 0, true);
145*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo varianceInfo({1}, armnn::DataType::Float32, 0.0f, 0, true);
146*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo betaInfo({1}, armnn::DataType::Float32, 0.0f, 0, true);
147*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo gammaInfo({1}, armnn::DataType::Float32, 0.0f, 0, true);
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker     armnn::BatchNormalizationDescriptor descriptor;
150*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Eps = 0.0010000000475f;
151*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> meanData({5.0});
154*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> varianceData({2.0});
155*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> betaData({1.0});
156*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> gammaData({0.0});
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::ConstTensor> constants;
159*89c4ff92SAndroid Build Coastguard Worker     constants.emplace_back(armnn::ConstTensor(meanInfo, meanData));
160*89c4ff92SAndroid Build Coastguard Worker     constants.emplace_back(armnn::ConstTensor(varianceInfo, varianceData));
161*89c4ff92SAndroid Build Coastguard Worker     constants.emplace_back(armnn::ConstTensor(betaInfo, betaData));
162*89c4ff92SAndroid Build Coastguard Worker     constants.emplace_back(armnn::ConstTensor(gammaInfo, gammaData));
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
165*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
166*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const batchNormalizationLayer =
167*89c4ff92SAndroid Build Coastguard Worker         network->AddBatchNormalizationLayer(descriptor,
168*89c4ff92SAndroid Build Coastguard Worker                                             constants[0],
169*89c4ff92SAndroid Build Coastguard Worker                                             constants[1],
170*89c4ff92SAndroid Build Coastguard Worker                                             constants[2],
171*89c4ff92SAndroid Build Coastguard Worker                                             constants[3],
172*89c4ff92SAndroid Build Coastguard Worker                                             layerName.c_str());
173*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(batchNormalizationLayer->GetInputSlot(0));
176*89c4ff92SAndroid Build Coastguard Worker     batchNormalizationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
179*89c4ff92SAndroid Build Coastguard Worker     batchNormalizationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
182*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::BatchNormalizationDescriptor> verifier(
185*89c4ff92SAndroid Build Coastguard Worker         layerName, {inputInfo}, {outputInfo}, descriptor, constants);
186*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
187*89c4ff92SAndroid Build Coastguard Worker }
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeBatchToSpaceNd")
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("spaceToBatchNd");
192*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({4, 1, 2, 2}, armnn::DataType::Float32);
193*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 1, 4, 4}, armnn::DataType::Float32);
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     armnn::BatchToSpaceNdDescriptor desc;
196*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
197*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockShape = {2, 2};
198*89c4ff92SAndroid Build Coastguard Worker     desc.m_Crops = {{0, 0}, {0, 0}};
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
201*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
202*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const batchToSpaceNdLayer = network->AddBatchToSpaceNdLayer(desc, layerName.c_str());
203*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(batchToSpaceNdLayer->GetInputSlot(0));
206*89c4ff92SAndroid Build Coastguard Worker     batchToSpaceNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
209*89c4ff92SAndroid Build Coastguard Worker     batchToSpaceNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
212*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::BatchToSpaceNdDescriptor> verifier(layerName,
215*89c4ff92SAndroid Build Coastguard Worker                                                                               {inputInfo},
216*89c4ff92SAndroid Build Coastguard Worker                                                                               {outputInfo},
217*89c4ff92SAndroid Build Coastguard Worker                                                                               desc);
218*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker 
221*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeCast")
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker         const std::string layerName("cast");
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorShape shape{1, 5, 2, 3};
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Signed32);
228*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32);
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = armnn::INetwork::Create();
231*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* inputLayer      = network->AddInputLayer(0);
232*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* castLayer       = network->AddCastLayer(layerName.c_str());
233*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* outputLayer     = network->AddOutputLayer(0);
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker         inputLayer->GetOutputSlot(0).Connect(castLayer->GetInputSlot(0));
236*89c4ff92SAndroid Build Coastguard Worker         castLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
237*89c4ff92SAndroid Build Coastguard Worker 
238*89c4ff92SAndroid Build Coastguard Worker         inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
239*89c4ff92SAndroid Build Coastguard Worker         castLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
242*89c4ff92SAndroid Build Coastguard Worker         CHECK(deserializedNetwork);
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker         LayerVerifierBase verifier(layerName, {inputInfo}, {outputInfo});
245*89c4ff92SAndroid Build Coastguard Worker         deserializedNetwork->ExecuteStrategy(verifier);
246*89c4ff92SAndroid Build Coastguard Worker }
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeChannelShuffle")
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("channelShuffle");
251*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 9}, armnn::DataType::Float32);
252*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 9}, armnn::DataType::Float32);
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker     armnn::ChannelShuffleDescriptor descriptor({3, 1});
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
257*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
258*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const ChannelShuffleLayer =
259*89c4ff92SAndroid Build Coastguard Worker             network->AddChannelShuffleLayer(descriptor, layerName.c_str());
260*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(ChannelShuffleLayer->GetInputSlot(0));
263*89c4ff92SAndroid Build Coastguard Worker     ChannelShuffleLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
264*89c4ff92SAndroid Build Coastguard Worker 
265*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
266*89c4ff92SAndroid Build Coastguard Worker     ChannelShuffleLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
269*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
270*89c4ff92SAndroid Build Coastguard Worker 
271*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ChannelShuffleDescriptor> verifier(
272*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, descriptor);
273*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeComparison")
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("comparison");
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 4};
281*89c4ff92SAndroid Build Coastguard Worker 
282*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
283*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor descriptor(armnn::ComparisonOperation::NotEqual);
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
288*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0     = network->AddInputLayer(0);
289*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1     = network->AddInputLayer(1);
290*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const comparisonLayer = network->AddComparisonLayer(descriptor, layerName.c_str());
291*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer     = network->AddOutputLayer(0);
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(comparisonLayer->GetInputSlot(0));
294*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(comparisonLayer->GetInputSlot(1));
295*89c4ff92SAndroid Build Coastguard Worker     comparisonLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
298*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
299*89c4ff92SAndroid Build Coastguard Worker     comparisonLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
302*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ComparisonDescriptor> verifier(layerName,
305*89c4ff92SAndroid Build Coastguard Worker                                                                           { inputInfo, inputInfo },
306*89c4ff92SAndroid Build Coastguard Worker                                                                           { outputInfo },
307*89c4ff92SAndroid Build Coastguard Worker                                                                           descriptor);
308*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConstant")
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker     class ConstantLayerVerifier : public LayerVerifierBase
314*89c4ff92SAndroid Build Coastguard Worker     {
315*89c4ff92SAndroid Build Coastguard Worker     public:
ConstantLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const std::vector<armnn::ConstTensor> & constants)316*89c4ff92SAndroid Build Coastguard Worker         ConstantLayerVerifier(const std::string& layerName,
317*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<armnn::TensorInfo>& inputInfos,
318*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<armnn::TensorInfo>& outputInfos,
319*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<armnn::ConstTensor>& constants)
320*89c4ff92SAndroid Build Coastguard Worker             : LayerVerifierBase(layerName, inputInfos, outputInfos)
321*89c4ff92SAndroid Build Coastguard Worker             , m_Constants(constants) {}
322*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)323*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
324*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor& descriptor,
325*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
326*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
327*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
328*89c4ff92SAndroid Build Coastguard Worker         {
329*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(descriptor, id);
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
332*89c4ff92SAndroid Build Coastguard Worker             {
333*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input: break;
334*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output: break;
335*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Addition: break;
336*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::ElementwiseBinary: break;
337*89c4ff92SAndroid Build Coastguard Worker                 default:
338*89c4ff92SAndroid Build Coastguard Worker                 {
339*89c4ff92SAndroid Build Coastguard Worker                     this->VerifyNameAndConnections(layer, name);
340*89c4ff92SAndroid Build Coastguard Worker 
341*89c4ff92SAndroid Build Coastguard Worker                     for (std::size_t i = 0; i < constants.size(); i++)
342*89c4ff92SAndroid Build Coastguard Worker                     {
343*89c4ff92SAndroid Build Coastguard Worker                         CompareConstTensor(constants[i], m_Constants[i]);
344*89c4ff92SAndroid Build Coastguard Worker                     }
345*89c4ff92SAndroid Build Coastguard Worker                 }
346*89c4ff92SAndroid Build Coastguard Worker             }
347*89c4ff92SAndroid Build Coastguard Worker         }
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     private:
350*89c4ff92SAndroid Build Coastguard Worker         const std::vector<armnn::ConstTensor> m_Constants;
351*89c4ff92SAndroid Build Coastguard Worker     };
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("constant");
354*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 2, 3 }, armnn::DataType::Float32, 0.0f, 0, true);
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> constantData = GenerateRandomData<float>(info.GetNumElements());
357*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor constTensor(info, constantData);
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network(armnn::INetwork::Create());
360*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* input = network->AddInputLayer(0);
361*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* constant = network->AddConstantLayer(constTensor, layerName.c_str());
362*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
363*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* add = network->AddAdditionLayer();
364*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
365*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* output = network->AddOutputLayer(0);
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
368*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
369*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
370*89c4ff92SAndroid Build Coastguard Worker 
371*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(info);
372*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).SetTensorInfo(info);
373*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
374*89c4ff92SAndroid Build Coastguard Worker 
375*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
376*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
377*89c4ff92SAndroid Build Coastguard Worker 
378*89c4ff92SAndroid Build Coastguard Worker     ConstantLayerVerifier verifier(layerName, {}, {info}, {constTensor});
379*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
380*89c4ff92SAndroid Build Coastguard Worker }
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker using Convolution2dDescriptor = armnn::Convolution2dDescriptor;
383*89c4ff92SAndroid Build Coastguard Worker class Convolution2dLayerVerifier : public LayerVerifierBaseWithDescriptor<Convolution2dDescriptor>
384*89c4ff92SAndroid Build Coastguard Worker {
385*89c4ff92SAndroid Build Coastguard Worker public:
Convolution2dLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const Convolution2dDescriptor & descriptor)386*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayerVerifier(const std::string& layerName,
387*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& inputInfos,
388*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& outputInfos,
389*89c4ff92SAndroid Build Coastguard Worker                         const Convolution2dDescriptor& descriptor)
390*89c4ff92SAndroid Build Coastguard Worker         : LayerVerifierBaseWithDescriptor<Convolution2dDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
391*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)392*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
393*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
394*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
395*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
396*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
397*89c4ff92SAndroid Build Coastguard Worker     {
398*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(constants, id);
399*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
400*89c4ff92SAndroid Build Coastguard Worker         {
401*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
402*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
403*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Constant: break;
404*89c4ff92SAndroid Build Coastguard Worker             default:
405*89c4ff92SAndroid Build Coastguard Worker             {
406*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
407*89c4ff92SAndroid Build Coastguard Worker                 const Convolution2dDescriptor& layerDescriptor =
408*89c4ff92SAndroid Build Coastguard Worker                         static_cast<const Convolution2dDescriptor&>(descriptor);
409*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_BiasEnabled == m_Descriptor.m_BiasEnabled);
410*89c4ff92SAndroid Build Coastguard Worker             }
411*89c4ff92SAndroid Build Coastguard Worker         }
412*89c4ff92SAndroid Build Coastguard Worker     }
413*89c4ff92SAndroid Build Coastguard Worker };
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConvolution2d")
416*89c4ff92SAndroid Build Coastguard Worker {
417*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("convolution2d");
418*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 5, 5, 1 }, armnn::DataType::Float32);
419*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
420*89c4ff92SAndroid Build Coastguard Worker 
421*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
422*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32, 0.0f, 0, true);
423*89c4ff92SAndroid Build Coastguard Worker 
424*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
425*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
426*89c4ff92SAndroid Build Coastguard Worker 
427*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
428*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
429*89c4ff92SAndroid Build Coastguard Worker 
430*89c4ff92SAndroid Build Coastguard Worker     armnn::Convolution2dDescriptor descriptor;
431*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
432*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
433*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
434*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
435*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
436*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
437*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 2;
438*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 2;
439*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
440*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
441*89c4ff92SAndroid Build Coastguard Worker 
442*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
443*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
444*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "weights");
445*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasLayer = network->AddConstantLayer(biases, "bias");
446*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer = network->AddConvolution2dLayer(descriptor, layerName.c_str());
447*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
448*89c4ff92SAndroid Build Coastguard Worker 
449*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
450*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
451*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
452*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
453*89c4ff92SAndroid Build Coastguard Worker 
454*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
455*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
456*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
457*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
458*89c4ff92SAndroid Build Coastguard Worker 
459*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
460*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
461*89c4ff92SAndroid Build Coastguard Worker 
462*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayerVerifier verifier(layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor);
463*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
464*89c4ff92SAndroid Build Coastguard Worker }
465*89c4ff92SAndroid Build Coastguard Worker 
466*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConvolution2dWithPerAxisParams")
467*89c4ff92SAndroid Build Coastguard Worker {
468*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("convolution2dWithPerAxis");
471*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({ 1, 3, 1, 2 }, DataType::QAsymmU8, 0.55f, 128);
472*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({ 1, 3, 1, 3 }, DataType::QAsymmU8, 0.75f, 128);
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> quantScales{ 0.75f, 0.65f, 0.85f };
475*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int quantDimension = 0;
476*89c4ff92SAndroid Build Coastguard Worker 
477*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo kernelInfo({ 3, 1, 1, 2 }, DataType::QSymmS8, quantScales, quantDimension, true);
478*89c4ff92SAndroid Build Coastguard Worker 
479*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> biasQuantScales{ 0.25f, 0.50f, 0.75f };
480*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo biasInfo({ 3 }, DataType::Signed32, biasQuantScales, quantDimension, true);
481*89c4ff92SAndroid Build Coastguard Worker 
482*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> kernelData = GenerateRandomData<int8_t>(kernelInfo.GetNumElements());
483*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(kernelInfo, kernelData);
484*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasData = GenerateRandomData<int32_t>(biasInfo.GetNumElements());
485*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasInfo, biasData);
486*89c4ff92SAndroid Build Coastguard Worker 
487*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor descriptor;
488*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 1;
489*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
490*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
491*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
492*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 0;
493*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
494*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
495*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
496*89c4ff92SAndroid Build Coastguard Worker 
497*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
498*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
499*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "weights");
500*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasLayer = network->AddConstantLayer(weights, "bias");
501*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer = network->AddConvolution2dLayer(descriptor, layerName.c_str());
502*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
503*89c4ff92SAndroid Build Coastguard Worker 
504*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
505*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
506*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
507*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
508*89c4ff92SAndroid Build Coastguard Worker 
509*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
510*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(kernelInfo);
511*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
512*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
515*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
516*89c4ff92SAndroid Build Coastguard Worker 
517*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayerVerifier verifier(layerName, {inputInfo, kernelInfo, biasInfo}, {outputInfo}, descriptor);
518*89c4ff92SAndroid Build Coastguard Worker 
519*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
520*89c4ff92SAndroid Build Coastguard Worker }
521*89c4ff92SAndroid Build Coastguard Worker 
522*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConvolution2dWeightsAndBiasesAsConstantLayers")
523*89c4ff92SAndroid Build Coastguard Worker {
524*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("convolution2d");
525*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 5, 5, 1 }, armnn::DataType::Float32);
526*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
527*89c4ff92SAndroid Build Coastguard Worker 
528*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
529*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32, 0.0f, 0, true);
530*89c4ff92SAndroid Build Coastguard Worker 
531*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
532*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
533*89c4ff92SAndroid Build Coastguard Worker 
534*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
535*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
536*89c4ff92SAndroid Build Coastguard Worker 
537*89c4ff92SAndroid Build Coastguard Worker     armnn::Convolution2dDescriptor descriptor;
538*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
539*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
540*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
541*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
542*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
543*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
544*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 2;
545*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 2;
546*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
547*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
548*89c4ff92SAndroid Build Coastguard Worker 
549*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
550*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
551*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "Weights");
552*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasesLayer = network->AddConstantLayer(biases, "Biases");
553*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer   = network->AddConvolution2dLayer(descriptor,
554*89c4ff92SAndroid Build Coastguard Worker                                            layerName.c_str());
555*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
556*89c4ff92SAndroid Build Coastguard Worker 
557*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
558*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
559*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
560*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
561*89c4ff92SAndroid Build Coastguard Worker 
562*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
563*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
564*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
565*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
566*89c4ff92SAndroid Build Coastguard Worker 
567*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
568*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
569*89c4ff92SAndroid Build Coastguard Worker 
570*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor>& constants {weights, biases};
571*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::Convolution2dDescriptor> verifier(
572*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor, constants);
573*89c4ff92SAndroid Build Coastguard Worker 
574*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
575*89c4ff92SAndroid Build Coastguard Worker }
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConvolution3d")
578*89c4ff92SAndroid Build Coastguard Worker {
579*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("convolution3d");
580*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 5, 5, 5, 1 }, armnn::DataType::Float32);
581*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 2, 2, 2, 1 }, armnn::DataType::Float32);
582*89c4ff92SAndroid Build Coastguard Worker 
583*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 3, 3, 3, 1, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
584*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32, 0.0f, 0, true);
585*89c4ff92SAndroid Build Coastguard Worker 
586*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
587*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
588*89c4ff92SAndroid Build Coastguard Worker 
589*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
590*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
591*89c4ff92SAndroid Build Coastguard Worker 
592*89c4ff92SAndroid Build Coastguard Worker     armnn::Convolution3dDescriptor descriptor;
593*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
594*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
595*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 0;
596*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
597*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadFront    = 0;
598*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBack     = 0;
599*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 1;
600*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 1;
601*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationZ   = 1;
602*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
603*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
604*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideZ     = 2;
605*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
606*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NDHWC;
607*89c4ff92SAndroid Build Coastguard Worker 
608*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
609*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
610*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "Weights");
611*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasesLayer = network->AddConstantLayer(biases, "Biases");
612*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer   = network->AddConvolution3dLayer(descriptor, layerName.c_str());
613*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
614*89c4ff92SAndroid Build Coastguard Worker 
615*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
616*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
617*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
618*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
619*89c4ff92SAndroid Build Coastguard Worker 
620*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
621*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
622*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
623*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
626*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
627*89c4ff92SAndroid Build Coastguard Worker 
628*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::Convolution3dDescriptor> verifier(
629*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor);
630*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
631*89c4ff92SAndroid Build Coastguard Worker }
632*89c4ff92SAndroid Build Coastguard Worker 
633*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDepthToSpace")
634*89c4ff92SAndroid Build Coastguard Worker {
635*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("depthToSpace");
636*89c4ff92SAndroid Build Coastguard Worker 
637*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1,  8, 4, 12 }, armnn::DataType::Float32);
638*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 16, 8,  3 }, armnn::DataType::Float32);
639*89c4ff92SAndroid Build Coastguard Worker 
640*89c4ff92SAndroid Build Coastguard Worker     armnn::DepthToSpaceDescriptor desc;
641*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockSize  = 2;
642*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
643*89c4ff92SAndroid Build Coastguard Worker 
644*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
645*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer        = network->AddInputLayer(0);
646*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const depthToSpaceLayer = network->AddDepthToSpaceLayer(desc, layerName.c_str());
647*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer       = network->AddOutputLayer(0);
648*89c4ff92SAndroid Build Coastguard Worker 
649*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(depthToSpaceLayer->GetInputSlot(0));
650*89c4ff92SAndroid Build Coastguard Worker     depthToSpaceLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
651*89c4ff92SAndroid Build Coastguard Worker 
652*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
653*89c4ff92SAndroid Build Coastguard Worker     depthToSpaceLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
654*89c4ff92SAndroid Build Coastguard Worker 
655*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
656*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
657*89c4ff92SAndroid Build Coastguard Worker 
658*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::DepthToSpaceDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, desc);
659*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
660*89c4ff92SAndroid Build Coastguard Worker }
661*89c4ff92SAndroid Build Coastguard Worker 
662*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDepthwiseConvolution2d")
663*89c4ff92SAndroid Build Coastguard Worker {
664*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("depwiseConvolution2d");
665*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 5, 5, 3 }, armnn::DataType::Float32);
666*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 3, 3, 3 }, armnn::DataType::Float32);
667*89c4ff92SAndroid Build Coastguard Worker 
668*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 1, 3, 3, 3 }, armnn::DataType::Float32, 0.0f, 0, true);
669*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32, 0.0f, 0, true);
670*89c4ff92SAndroid Build Coastguard Worker 
671*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
672*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
673*89c4ff92SAndroid Build Coastguard Worker 
674*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasesData = GenerateRandomData<int32_t>(biasesInfo.GetNumElements());
675*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
676*89c4ff92SAndroid Build Coastguard Worker 
677*89c4ff92SAndroid Build Coastguard Worker     armnn::DepthwiseConvolution2dDescriptor descriptor;
678*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
679*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
680*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
681*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
682*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
683*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
684*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 2;
685*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 2;
686*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
687*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
688*89c4ff92SAndroid Build Coastguard Worker 
689*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
690*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
691*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const depthwiseConvLayer = network->AddDepthwiseConvolution2dLayer(descriptor,
692*89c4ff92SAndroid Build Coastguard Worker                                                                                                  layerName.c_str());
693*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
694*89c4ff92SAndroid Build Coastguard Worker 
695*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(0));
696*89c4ff92SAndroid Build Coastguard Worker     depthwiseConvLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
697*89c4ff92SAndroid Build Coastguard Worker 
698*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
699*89c4ff92SAndroid Build Coastguard Worker     depthwiseConvLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
700*89c4ff92SAndroid Build Coastguard Worker 
701*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights);
702*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(1u));
703*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo());
704*89c4ff92SAndroid Build Coastguard Worker 
705*89c4ff92SAndroid Build Coastguard Worker      armnn::IConnectableLayer* const biasLayer = network->AddConstantLayer(biases);
706*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(2u));
707*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biases.GetInfo());
708*89c4ff92SAndroid Build Coastguard Worker 
709*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
710*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
711*89c4ff92SAndroid Build Coastguard Worker 
712*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor>& constants {weights, biases};
713*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::DepthwiseConvolution2dDescriptor> verifier(
714*89c4ff92SAndroid Build Coastguard Worker         layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor, constants);
715*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
716*89c4ff92SAndroid Build Coastguard Worker }
717*89c4ff92SAndroid Build Coastguard Worker 
718*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDepthwiseConvolution2dWithPerAxisParams")
719*89c4ff92SAndroid Build Coastguard Worker {
720*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
721*89c4ff92SAndroid Build Coastguard Worker 
722*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("depwiseConvolution2dWithPerAxis");
723*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({ 1, 3, 3, 2 }, DataType::QAsymmU8, 0.55f, 128);
724*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({ 1, 2, 2, 4 }, DataType::QAsymmU8, 0.75f, 128);
725*89c4ff92SAndroid Build Coastguard Worker 
726*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> quantScales{ 0.75f, 0.80f, 0.90f, 0.95f };
727*89c4ff92SAndroid Build Coastguard Worker     const unsigned int quantDimension = 0;
728*89c4ff92SAndroid Build Coastguard Worker     TensorInfo kernelInfo({ 2, 2, 2, 2 }, DataType::QSymmS8, quantScales, quantDimension, true);
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> biasQuantScales{ 0.25f, 0.35f, 0.45f, 0.55f };
731*89c4ff92SAndroid Build Coastguard Worker     constexpr unsigned int biasQuantDimension = 0;
732*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo({ 4 }, DataType::Signed32, biasQuantScales, biasQuantDimension, true);
733*89c4ff92SAndroid Build Coastguard Worker 
734*89c4ff92SAndroid Build Coastguard Worker     std::vector<int8_t> kernelData = GenerateRandomData<int8_t>(kernelInfo.GetNumElements());
735*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(kernelInfo, kernelData);
736*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasData = GenerateRandomData<int32_t>(biasInfo.GetNumElements());
737*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasInfo, biasData);
738*89c4ff92SAndroid Build Coastguard Worker 
739*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor descriptor;
740*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 1;
741*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
742*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 0;
743*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 0;
744*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 0;
745*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 0;
746*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 1;
747*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 1;
748*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
749*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
750*89c4ff92SAndroid Build Coastguard Worker 
751*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
752*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
753*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const depthwiseConvLayer = network->AddDepthwiseConvolution2dLayer(descriptor,
754*89c4ff92SAndroid Build Coastguard Worker                                                                                                  layerName.c_str());
755*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
756*89c4ff92SAndroid Build Coastguard Worker 
757*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(0));
758*89c4ff92SAndroid Build Coastguard Worker     depthwiseConvLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
759*89c4ff92SAndroid Build Coastguard Worker 
760*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
761*89c4ff92SAndroid Build Coastguard Worker     depthwiseConvLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
762*89c4ff92SAndroid Build Coastguard Worker 
763*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights);
764*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(1u));
765*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weights.GetInfo());
766*89c4ff92SAndroid Build Coastguard Worker 
767*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasLayer = network->AddConstantLayer(biases);
768*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(depthwiseConvLayer->GetInputSlot(2u));
769*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).SetTensorInfo(biases.GetInfo());
770*89c4ff92SAndroid Build Coastguard Worker 
771*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
772*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
773*89c4ff92SAndroid Build Coastguard Worker 
774*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor>& constants {weights, biases};
775*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::DepthwiseConvolution2dDescriptor> verifier(
776*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo, kernelInfo, biasInfo}, {outputInfo}, descriptor, constants);
777*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
778*89c4ff92SAndroid Build Coastguard Worker }
779*89c4ff92SAndroid Build Coastguard Worker 
780*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDepthwiseConvolution2dWeightsAndBiasesAsConstantLayers")
781*89c4ff92SAndroid Build Coastguard Worker {
782*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("depthwiseConvolution2d");
783*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 5, 5, 1 }, armnn::DataType::Float32);
784*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32);
785*89c4ff92SAndroid Build Coastguard Worker 
786*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
787*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32, 0.0f, 0, true);
788*89c4ff92SAndroid Build Coastguard Worker 
789*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
790*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
791*89c4ff92SAndroid Build Coastguard Worker 
792*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
793*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
794*89c4ff92SAndroid Build Coastguard Worker 
795*89c4ff92SAndroid Build Coastguard Worker     armnn::DepthwiseConvolution2dDescriptor descriptor;
796*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
797*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
798*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
799*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
800*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 2;
801*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 2;
802*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationX   = 2;
803*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DilationY   = 2;
804*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
805*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
806*89c4ff92SAndroid Build Coastguard Worker 
807*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
808*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
809*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "Weights");
810*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasesLayer = network->AddConstantLayer(biases, "Biases");
811*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer   = network->AddDepthwiseConvolution2dLayer(descriptor,
812*89c4ff92SAndroid Build Coastguard Worker                                                                                           layerName.c_str());
813*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
814*89c4ff92SAndroid Build Coastguard Worker 
815*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
816*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
817*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
818*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
819*89c4ff92SAndroid Build Coastguard Worker 
820*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
821*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
822*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
823*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
824*89c4ff92SAndroid Build Coastguard Worker 
825*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
826*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
827*89c4ff92SAndroid Build Coastguard Worker 
828*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor>& constants {weights, biases};
829*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::DepthwiseConvolution2dDescriptor> verifier(
830*89c4ff92SAndroid Build Coastguard Worker         layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor, constants);
831*89c4ff92SAndroid Build Coastguard Worker 
832*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
833*89c4ff92SAndroid Build Coastguard Worker }
834*89c4ff92SAndroid Build Coastguard Worker 
835*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDequantize")
836*89c4ff92SAndroid Build Coastguard Worker {
837*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("dequantize");
838*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({ 1, 5, 2, 3 }, armnn::DataType::QAsymmU8, 0.5f, 1);
839*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 5, 2, 3 }, armnn::DataType::Float32);
840*89c4ff92SAndroid Build Coastguard Worker 
841*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
842*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
843*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const dequantizeLayer = network->AddDequantizeLayer(layerName.c_str());
844*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
845*89c4ff92SAndroid Build Coastguard Worker 
846*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(dequantizeLayer->GetInputSlot(0));
847*89c4ff92SAndroid Build Coastguard Worker     dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
848*89c4ff92SAndroid Build Coastguard Worker 
849*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
850*89c4ff92SAndroid Build Coastguard Worker     dequantizeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
853*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
854*89c4ff92SAndroid Build Coastguard Worker 
855*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {inputInfo}, {outputInfo});
856*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
857*89c4ff92SAndroid Build Coastguard Worker }
858*89c4ff92SAndroid Build Coastguard Worker 
859*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDeserializeDetectionPostProcess")
860*89c4ff92SAndroid Build Coastguard Worker {
861*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("detectionPostProcess");
862*89c4ff92SAndroid Build Coastguard Worker 
863*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::TensorInfo> inputInfos({
864*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1, 6, 4 }, armnn::DataType::Float32),
865*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1, 6, 3}, armnn::DataType::Float32)
866*89c4ff92SAndroid Build Coastguard Worker     });
867*89c4ff92SAndroid Build Coastguard Worker 
868*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::TensorInfo> outputInfos({
869*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1, 3, 4 }, armnn::DataType::Float32),
870*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1, 3 }, armnn::DataType::Float32),
871*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1, 3 }, armnn::DataType::Float32),
872*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo({ 1 }, armnn::DataType::Float32)
873*89c4ff92SAndroid Build Coastguard Worker     });
874*89c4ff92SAndroid Build Coastguard Worker 
875*89c4ff92SAndroid Build Coastguard Worker     armnn::DetectionPostProcessDescriptor descriptor;
876*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_UseRegularNms = true;
877*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_MaxDetections = 3;
878*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_MaxClassesPerDetection = 1;
879*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DetectionsPerClass =1;
880*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NmsScoreThreshold = 0.0;
881*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NmsIouThreshold = 0.5;
882*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NumClasses = 2;
883*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ScaleY = 10.0;
884*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ScaleX = 10.0;
885*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ScaleH = 5.0;
886*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ScaleW = 5.0;
887*89c4ff92SAndroid Build Coastguard Worker 
888*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32, 0.0f, 0, true);
889*89c4ff92SAndroid Build Coastguard Worker     const std::vector<float> anchorsData({
890*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
891*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
892*89c4ff92SAndroid Build Coastguard Worker         0.5f, 0.5f, 1.0f, 1.0f,
893*89c4ff92SAndroid Build Coastguard Worker         0.5f, 10.5f, 1.0f, 1.0f,
894*89c4ff92SAndroid Build Coastguard Worker         0.5f, 10.5f, 1.0f, 1.0f,
895*89c4ff92SAndroid Build Coastguard Worker         0.5f, 100.5f, 1.0f, 1.0f
896*89c4ff92SAndroid Build Coastguard Worker     });
897*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor anchors(anchorsInfo, anchorsData);
898*89c4ff92SAndroid Build Coastguard Worker 
899*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
900*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const detectionLayer =
901*89c4ff92SAndroid Build Coastguard Worker         network->AddDetectionPostProcessLayer(descriptor, anchors, layerName.c_str());
902*89c4ff92SAndroid Build Coastguard Worker 
903*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < 2; i++)
904*89c4ff92SAndroid Build Coastguard Worker     {
905*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(static_cast<int>(i));
906*89c4ff92SAndroid Build Coastguard Worker         inputLayer->GetOutputSlot(0).Connect(detectionLayer->GetInputSlot(i));
907*89c4ff92SAndroid Build Coastguard Worker         inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfos[i]);
908*89c4ff92SAndroid Build Coastguard Worker     }
909*89c4ff92SAndroid Build Coastguard Worker 
910*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < 4; i++)
911*89c4ff92SAndroid Build Coastguard Worker     {
912*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(static_cast<int>(i));
913*89c4ff92SAndroid Build Coastguard Worker         detectionLayer->GetOutputSlot(i).Connect(outputLayer->GetInputSlot(0));
914*89c4ff92SAndroid Build Coastguard Worker         detectionLayer->GetOutputSlot(i).SetTensorInfo(outputInfos[i]);
915*89c4ff92SAndroid Build Coastguard Worker     }
916*89c4ff92SAndroid Build Coastguard Worker 
917*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
918*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
919*89c4ff92SAndroid Build Coastguard Worker 
920*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor>& constants {anchors};
921*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::DetectionPostProcessDescriptor> verifier(
922*89c4ff92SAndroid Build Coastguard Worker             layerName, inputInfos, outputInfos, descriptor, constants);
923*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
924*89c4ff92SAndroid Build Coastguard Worker }
925*89c4ff92SAndroid Build Coastguard Worker 
926*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDivision")
927*89c4ff92SAndroid Build Coastguard Worker {
928*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("division");
929*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 5, 2, 3 }, armnn::DataType::Float32);
930*89c4ff92SAndroid Build Coastguard Worker 
931*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
932*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
933*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
934*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
935*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const divisionLayer = network->AddDivisionLayer(layerName.c_str());
936*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
937*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
938*89c4ff92SAndroid Build Coastguard Worker 
939*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(divisionLayer->GetInputSlot(0));
940*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(divisionLayer->GetInputSlot(1));
941*89c4ff92SAndroid Build Coastguard Worker     divisionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
942*89c4ff92SAndroid Build Coastguard Worker 
943*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
944*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
945*89c4ff92SAndroid Build Coastguard Worker     divisionLayer->GetOutputSlot(0).SetTensorInfo(info);
946*89c4ff92SAndroid Build Coastguard Worker 
947*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
948*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
949*89c4ff92SAndroid Build Coastguard Worker 
950*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
951*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
952*89c4ff92SAndroid Build Coastguard Worker }
953*89c4ff92SAndroid Build Coastguard Worker 
954*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDeserializeComparisonEqual")
955*89c4ff92SAndroid Build Coastguard Worker {
956*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("EqualLayer");
957*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo1 = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Float32);
958*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo2 = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Float32);
959*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputTensorInfo = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Boolean);
960*89c4ff92SAndroid Build Coastguard Worker 
961*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
962*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(0);
963*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer2 = network->AddInputLayer(1);
964*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor equalDescriptor(armnn::ComparisonOperation::Equal);
965*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const equalLayer = network->AddComparisonLayer(equalDescriptor, layerName.c_str());
966*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
967*89c4ff92SAndroid Build Coastguard Worker 
968*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
969*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(inputTensorInfo1);
970*89c4ff92SAndroid Build Coastguard Worker     inputLayer2->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
971*89c4ff92SAndroid Build Coastguard Worker     inputLayer2->GetOutputSlot(0).SetTensorInfo(inputTensorInfo2);
972*89c4ff92SAndroid Build Coastguard Worker     equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
973*89c4ff92SAndroid Build Coastguard Worker     equalLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
974*89c4ff92SAndroid Build Coastguard Worker 
975*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
976*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
977*89c4ff92SAndroid Build Coastguard Worker 
978*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {inputTensorInfo1, inputTensorInfo2}, {outputTensorInfo});
979*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
980*89c4ff92SAndroid Build Coastguard Worker }
981*89c4ff92SAndroid Build Coastguard Worker 
SerializeElementwiseBinaryTest(armnn::BinaryOperation binaryOperation)982*89c4ff92SAndroid Build Coastguard Worker void SerializeElementwiseBinaryTest(armnn::BinaryOperation binaryOperation)
983*89c4ff92SAndroid Build Coastguard Worker {
984*89c4ff92SAndroid Build Coastguard Worker     auto layerName = GetBinaryOperationAsCString(binaryOperation);
985*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo tensorInfo({ 1, 5, 2, 3 }, armnn::DataType::Float32);
986*89c4ff92SAndroid Build Coastguard Worker     armnn::ElementwiseBinaryDescriptor descriptor(binaryOperation);
987*89c4ff92SAndroid Build Coastguard Worker 
988*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
989*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
990*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
991*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const elementwiseBinaryLayer = network->AddElementwiseBinaryLayer(descriptor,
992*89c4ff92SAndroid Build Coastguard Worker                                                                                                 layerName);
993*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
994*89c4ff92SAndroid Build Coastguard Worker 
995*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(elementwiseBinaryLayer->GetInputSlot(0));
996*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(elementwiseBinaryLayer->GetInputSlot(1));
997*89c4ff92SAndroid Build Coastguard Worker     elementwiseBinaryLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
998*89c4ff92SAndroid Build Coastguard Worker 
999*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(tensorInfo);
1000*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
1001*89c4ff92SAndroid Build Coastguard Worker     elementwiseBinaryLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
1002*89c4ff92SAndroid Build Coastguard Worker 
1003*89c4ff92SAndroid Build Coastguard Worker     std::string serializedNetwork = SerializeNetwork(*network);
1004*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(serializedNetwork);
1005*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1006*89c4ff92SAndroid Build Coastguard Worker 
1007*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ElementwiseBinaryDescriptor>
1008*89c4ff92SAndroid Build Coastguard Worker             verifier(layerName, { tensorInfo, tensorInfo }, { tensorInfo }, descriptor);
1009*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1010*89c4ff92SAndroid Build Coastguard Worker }
1011*89c4ff92SAndroid Build Coastguard Worker 
1012*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeElementwiseBinary")
1013*89c4ff92SAndroid Build Coastguard Worker {
1014*89c4ff92SAndroid Build Coastguard Worker     using op = armnn::BinaryOperation;
1015*89c4ff92SAndroid Build Coastguard Worker     std::initializer_list<op> allBinaryOperations = {op::Add, op::Div, op::Maximum, op::Minimum, op::Mul, op::Sub};
1016*89c4ff92SAndroid Build Coastguard Worker 
1017*89c4ff92SAndroid Build Coastguard Worker     for (auto binaryOperation : allBinaryOperations)
1018*89c4ff92SAndroid Build Coastguard Worker     {
1019*89c4ff92SAndroid Build Coastguard Worker         SerializeElementwiseBinaryTest(binaryOperation);
1020*89c4ff92SAndroid Build Coastguard Worker     }
1021*89c4ff92SAndroid Build Coastguard Worker }
1022*89c4ff92SAndroid Build Coastguard Worker 
SerializeElementwiseUnaryTest(armnn::UnaryOperation unaryOperation)1023*89c4ff92SAndroid Build Coastguard Worker void SerializeElementwiseUnaryTest(armnn::UnaryOperation unaryOperation)
1024*89c4ff92SAndroid Build Coastguard Worker {
1025*89c4ff92SAndroid Build Coastguard Worker     auto layerName = GetUnaryOperationAsCString(unaryOperation);
1026*89c4ff92SAndroid Build Coastguard Worker 
1027*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 2};
1028*89c4ff92SAndroid Build Coastguard Worker 
1029*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
1030*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32);
1031*89c4ff92SAndroid Build Coastguard Worker 
1032*89c4ff92SAndroid Build Coastguard Worker     armnn::ElementwiseUnaryDescriptor descriptor(unaryOperation);
1033*89c4ff92SAndroid Build Coastguard Worker 
1034*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1035*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1036*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const elementwiseUnaryLayer =
1037*89c4ff92SAndroid Build Coastguard Worker                                 network->AddElementwiseUnaryLayer(descriptor, layerName);
1038*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1039*89c4ff92SAndroid Build Coastguard Worker 
1040*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(elementwiseUnaryLayer->GetInputSlot(0));
1041*89c4ff92SAndroid Build Coastguard Worker     elementwiseUnaryLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1042*89c4ff92SAndroid Build Coastguard Worker 
1043*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1044*89c4ff92SAndroid Build Coastguard Worker     elementwiseUnaryLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1045*89c4ff92SAndroid Build Coastguard Worker 
1046*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1047*89c4ff92SAndroid Build Coastguard Worker 
1048*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1049*89c4ff92SAndroid Build Coastguard Worker 
1050*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ElementwiseUnaryDescriptor>
1051*89c4ff92SAndroid Build Coastguard Worker         verifier(layerName, { inputInfo }, { outputInfo }, descriptor);
1052*89c4ff92SAndroid Build Coastguard Worker 
1053*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1054*89c4ff92SAndroid Build Coastguard Worker }
1055*89c4ff92SAndroid Build Coastguard Worker 
1056*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeElementwiseUnary")
1057*89c4ff92SAndroid Build Coastguard Worker {
1058*89c4ff92SAndroid Build Coastguard Worker     using op = armnn::UnaryOperation;
1059*89c4ff92SAndroid Build Coastguard Worker     std::initializer_list<op> allUnaryOperations = {op::Abs, op::Ceil, op::Exp, op::Sqrt, op::Rsqrt, op::Neg,
1060*89c4ff92SAndroid Build Coastguard Worker                                                     op::LogicalNot, op::Log, op::Sin};
1061*89c4ff92SAndroid Build Coastguard Worker 
1062*89c4ff92SAndroid Build Coastguard Worker     for (auto unaryOperation : allUnaryOperations)
1063*89c4ff92SAndroid Build Coastguard Worker     {
1064*89c4ff92SAndroid Build Coastguard Worker         SerializeElementwiseUnaryTest(unaryOperation);
1065*89c4ff92SAndroid Build Coastguard Worker     }
1066*89c4ff92SAndroid Build Coastguard Worker }
1067*89c4ff92SAndroid Build Coastguard Worker 
1068*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeFill")
1069*89c4ff92SAndroid Build Coastguard Worker {
1070*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("fill");
1071*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({4}, armnn::DataType::Signed32);
1072*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 3, 3, 1}, armnn::DataType::Float32);
1073*89c4ff92SAndroid Build Coastguard Worker 
1074*89c4ff92SAndroid Build Coastguard Worker     armnn::FillDescriptor descriptor(1.0f);
1075*89c4ff92SAndroid Build Coastguard Worker 
1076*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1077*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1078*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const fillLayer = network->AddFillLayer(descriptor, layerName.c_str());
1079*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1080*89c4ff92SAndroid Build Coastguard Worker 
1081*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(fillLayer->GetInputSlot(0));
1082*89c4ff92SAndroid Build Coastguard Worker     fillLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1083*89c4ff92SAndroid Build Coastguard Worker 
1084*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1085*89c4ff92SAndroid Build Coastguard Worker     fillLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1086*89c4ff92SAndroid Build Coastguard Worker 
1087*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1088*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1089*89c4ff92SAndroid Build Coastguard Worker 
1090*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::FillDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, descriptor);
1091*89c4ff92SAndroid Build Coastguard Worker 
1092*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1093*89c4ff92SAndroid Build Coastguard Worker }
1094*89c4ff92SAndroid Build Coastguard Worker 
1095*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeFloor")
1096*89c4ff92SAndroid Build Coastguard Worker {
1097*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("floor");
1098*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({4,4}, armnn::DataType::Float32);
1099*89c4ff92SAndroid Build Coastguard Worker 
1100*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1101*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1102*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const floorLayer = network->AddFloorLayer(layerName.c_str());
1103*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1104*89c4ff92SAndroid Build Coastguard Worker 
1105*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(floorLayer->GetInputSlot(0));
1106*89c4ff92SAndroid Build Coastguard Worker     floorLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1107*89c4ff92SAndroid Build Coastguard Worker 
1108*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
1109*89c4ff92SAndroid Build Coastguard Worker     floorLayer->GetOutputSlot(0).SetTensorInfo(info);
1110*89c4ff92SAndroid Build Coastguard Worker 
1111*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1112*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1113*89c4ff92SAndroid Build Coastguard Worker 
1114*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info}, {info});
1115*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1116*89c4ff92SAndroid Build Coastguard Worker }
1117*89c4ff92SAndroid Build Coastguard Worker 
1118*89c4ff92SAndroid Build Coastguard Worker using FullyConnectedDescriptor = armnn::FullyConnectedDescriptor;
1119*89c4ff92SAndroid Build Coastguard Worker class FullyConnectedLayerVerifier : public LayerVerifierBaseWithDescriptor<FullyConnectedDescriptor>
1120*89c4ff92SAndroid Build Coastguard Worker {
1121*89c4ff92SAndroid Build Coastguard Worker public:
FullyConnectedLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const FullyConnectedDescriptor & descriptor)1122*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedLayerVerifier(const std::string& layerName,
1123*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& inputInfos,
1124*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& outputInfos,
1125*89c4ff92SAndroid Build Coastguard Worker                         const FullyConnectedDescriptor& descriptor)
1126*89c4ff92SAndroid Build Coastguard Worker         : LayerVerifierBaseWithDescriptor<FullyConnectedDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
1127*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)1128*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
1129*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
1130*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
1131*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
1132*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
1133*89c4ff92SAndroid Build Coastguard Worker     {
1134*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(constants, id);
1135*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
1136*89c4ff92SAndroid Build Coastguard Worker         {
1137*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
1138*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
1139*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Constant: break;
1140*89c4ff92SAndroid Build Coastguard Worker             default:
1141*89c4ff92SAndroid Build Coastguard Worker             {
1142*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
1143*89c4ff92SAndroid Build Coastguard Worker                 const FullyConnectedDescriptor& layerDescriptor =
1144*89c4ff92SAndroid Build Coastguard Worker                         static_cast<const FullyConnectedDescriptor&>(descriptor);
1145*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_ConstantWeights == m_Descriptor.m_ConstantWeights);
1146*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_BiasEnabled == m_Descriptor.m_BiasEnabled);
1147*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_TransposeWeightMatrix == m_Descriptor.m_TransposeWeightMatrix);
1148*89c4ff92SAndroid Build Coastguard Worker             }
1149*89c4ff92SAndroid Build Coastguard Worker         }
1150*89c4ff92SAndroid Build Coastguard Worker     }
1151*89c4ff92SAndroid Build Coastguard Worker };
1152*89c4ff92SAndroid Build Coastguard Worker 
1153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeFullyConnected")
1154*89c4ff92SAndroid Build Coastguard Worker {
1155*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("fullyConnected");
1156*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 2, 5, 1, 1 }, armnn::DataType::Float32);
1157*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 2, 3 }, armnn::DataType::Float32);
1158*89c4ff92SAndroid Build Coastguard Worker 
1159*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 5, 3 }, armnn::DataType::Float32, 0.0f, 0, true);
1160*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32, 0.0f, 0, true);
1161*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
1162*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData  = GenerateRandomData<float>(biasesInfo.GetNumElements());
1163*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
1164*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
1165*89c4ff92SAndroid Build Coastguard Worker 
1166*89c4ff92SAndroid Build Coastguard Worker     armnn::FullyConnectedDescriptor descriptor;
1167*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
1168*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TransposeWeightMatrix = false;
1169*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ConstantWeights = true;
1170*89c4ff92SAndroid Build Coastguard Worker 
1171*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1172*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1173*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsInputLayer = network->AddInputLayer(1);
1174*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasInputLayer = network->AddInputLayer(2);
1175*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const fullyConnectedLayer =
1176*89c4ff92SAndroid Build Coastguard Worker             network->AddFullyConnectedLayer(descriptor,
1177*89c4ff92SAndroid Build Coastguard Worker                                             layerName.c_str());
1178*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1179*89c4ff92SAndroid Build Coastguard Worker 
1180*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
1181*89c4ff92SAndroid Build Coastguard Worker     weightsInputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
1182*89c4ff92SAndroid Build Coastguard Worker     biasInputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(2));
1183*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1184*89c4ff92SAndroid Build Coastguard Worker 
1185*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1186*89c4ff92SAndroid Build Coastguard Worker     weightsInputLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
1187*89c4ff92SAndroid Build Coastguard Worker     biasInputLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
1188*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1189*89c4ff92SAndroid Build Coastguard Worker 
1190*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1191*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1192*89c4ff92SAndroid Build Coastguard Worker 
1193*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedLayerVerifier verifier(layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor);
1194*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1195*89c4ff92SAndroid Build Coastguard Worker }
1196*89c4ff92SAndroid Build Coastguard Worker 
1197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeFullyConnectedWeightsAndBiasesAsInputs")
1198*89c4ff92SAndroid Build Coastguard Worker {
1199*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("fullyConnected_weights_as_inputs");
1200*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 2, 5, 1, 1 }, armnn::DataType::Float32);
1201*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 2, 3 }, armnn::DataType::Float32);
1202*89c4ff92SAndroid Build Coastguard Worker 
1203*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 5, 3 }, armnn::DataType::Float32);
1204*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32);
1205*89c4ff92SAndroid Build Coastguard Worker 
1206*89c4ff92SAndroid Build Coastguard Worker     armnn::Optional<armnn::ConstTensor> weights = armnn::EmptyOptional();
1207*89c4ff92SAndroid Build Coastguard Worker     armnn::Optional<armnn::ConstTensor> bias = armnn::EmptyOptional();
1208*89c4ff92SAndroid Build Coastguard Worker 
1209*89c4ff92SAndroid Build Coastguard Worker     armnn::FullyConnectedDescriptor descriptor;
1210*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
1211*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TransposeWeightMatrix = false;
1212*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ConstantWeights = false;
1213*89c4ff92SAndroid Build Coastguard Worker 
1214*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1215*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1216*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsInputLayer = network->AddInputLayer(1);
1217*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasInputLayer = network->AddInputLayer(2);
1218*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const fullyConnectedLayer =
1219*89c4ff92SAndroid Build Coastguard Worker         network->AddFullyConnectedLayer(descriptor,
1220*89c4ff92SAndroid Build Coastguard Worker                                         layerName.c_str());
1221*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1222*89c4ff92SAndroid Build Coastguard Worker 
1223*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
1224*89c4ff92SAndroid Build Coastguard Worker     weightsInputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
1225*89c4ff92SAndroid Build Coastguard Worker     biasInputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(2));
1226*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1227*89c4ff92SAndroid Build Coastguard Worker 
1228*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1229*89c4ff92SAndroid Build Coastguard Worker     weightsInputLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
1230*89c4ff92SAndroid Build Coastguard Worker     biasInputLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
1231*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1232*89c4ff92SAndroid Build Coastguard Worker 
1233*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1234*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1235*89c4ff92SAndroid Build Coastguard Worker 
1236*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor> constants {};
1237*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::FullyConnectedDescriptor> verifier(
1238*89c4ff92SAndroid Build Coastguard Worker         layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor, constants);
1239*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1240*89c4ff92SAndroid Build Coastguard Worker }
1241*89c4ff92SAndroid Build Coastguard Worker 
1242*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeFullyConnectedWeightsAndBiasesAsConstantLayers")
1243*89c4ff92SAndroid Build Coastguard Worker {
1244*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("fullyConnected_weights_as_inputs");
1245*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 2, 5, 1, 1 }, armnn::DataType::Float32);
1246*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 2, 3 }, armnn::DataType::Float32);
1247*89c4ff92SAndroid Build Coastguard Worker 
1248*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 5, 3 }, armnn::DataType::Float32, 0.0f, 0, true);
1249*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 3 }, armnn::DataType::Float32, 0.0f, 0, true);
1250*89c4ff92SAndroid Build Coastguard Worker 
1251*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
1252*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData  = GenerateRandomData<float>(biasesInfo.GetNumElements());
1253*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
1254*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
1255*89c4ff92SAndroid Build Coastguard Worker 
1256*89c4ff92SAndroid Build Coastguard Worker     armnn::FullyConnectedDescriptor descriptor;
1257*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
1258*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TransposeWeightMatrix = false;
1259*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ConstantWeights = true;
1260*89c4ff92SAndroid Build Coastguard Worker 
1261*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1262*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1263*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const weightsLayer = network->AddConstantLayer(weights, "Weights");
1264*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const biasesLayer = network->AddConstantLayer(biases, "Biases");
1265*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor,layerName.c_str());
1266*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1267*89c4ff92SAndroid Build Coastguard Worker 
1268*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
1269*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
1270*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(2));
1271*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1272*89c4ff92SAndroid Build Coastguard Worker 
1273*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1274*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
1275*89c4ff92SAndroid Build Coastguard Worker     biasesLayer->GetOutputSlot(0).SetTensorInfo(biasesInfo);
1276*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1277*89c4ff92SAndroid Build Coastguard Worker 
1278*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1279*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1280*89c4ff92SAndroid Build Coastguard Worker 
1281*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedLayerVerifier verifier(layerName, {inputInfo, weightsInfo, biasesInfo}, {outputInfo}, descriptor);
1282*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1283*89c4ff92SAndroid Build Coastguard Worker }
1284*89c4ff92SAndroid Build Coastguard Worker 
1285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeGather")
1286*89c4ff92SAndroid Build Coastguard Worker {
1287*89c4ff92SAndroid Build Coastguard Worker     using GatherDescriptor = armnn::GatherDescriptor;
1288*89c4ff92SAndroid Build Coastguard Worker     class GatherLayerVerifier : public LayerVerifierBaseWithDescriptor<GatherDescriptor>
1289*89c4ff92SAndroid Build Coastguard Worker     {
1290*89c4ff92SAndroid Build Coastguard Worker     public:
GatherLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const GatherDescriptor & descriptor)1291*89c4ff92SAndroid Build Coastguard Worker         GatherLayerVerifier(const std::string& layerName,
1292*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& inputInfos,
1293*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& outputInfos,
1294*89c4ff92SAndroid Build Coastguard Worker                             const GatherDescriptor& descriptor)
1295*89c4ff92SAndroid Build Coastguard Worker             : LayerVerifierBaseWithDescriptor<GatherDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
1296*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)1297*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
1298*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor& descriptor,
1299*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
1300*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
1301*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
1302*89c4ff92SAndroid Build Coastguard Worker         {
1303*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(constants, id);
1304*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
1305*89c4ff92SAndroid Build Coastguard Worker             {
1306*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input: break;
1307*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output: break;
1308*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Constant: break;
1309*89c4ff92SAndroid Build Coastguard Worker                 default:
1310*89c4ff92SAndroid Build Coastguard Worker                 {
1311*89c4ff92SAndroid Build Coastguard Worker                     VerifyNameAndConnections(layer, name);
1312*89c4ff92SAndroid Build Coastguard Worker                     const GatherDescriptor& layerDescriptor = static_cast<const GatherDescriptor&>(descriptor);
1313*89c4ff92SAndroid Build Coastguard Worker                     CHECK(layerDescriptor.m_Axis == m_Descriptor.m_Axis);
1314*89c4ff92SAndroid Build Coastguard Worker                 }
1315*89c4ff92SAndroid Build Coastguard Worker             }
1316*89c4ff92SAndroid Build Coastguard Worker         }
1317*89c4ff92SAndroid Build Coastguard Worker     };
1318*89c4ff92SAndroid Build Coastguard Worker 
1319*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("gather");
1320*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::QAsymmU8);
1321*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 3 }, armnn::DataType::QAsymmU8);
1322*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32, 0.0f, 0, true);
1323*89c4ff92SAndroid Build Coastguard Worker     GatherDescriptor descriptor;
1324*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Axis = 1;
1325*89c4ff92SAndroid Build Coastguard Worker 
1326*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationScale(1.0f);
1327*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationOffset(0);
1328*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationScale(1.0f);
1329*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationOffset(0);
1330*89c4ff92SAndroid Build Coastguard Worker 
1331*89c4ff92SAndroid Build Coastguard Worker     const std::vector<int32_t>& indicesData = {7, 6, 5};
1332*89c4ff92SAndroid Build Coastguard Worker 
1333*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1334*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
1335*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const constantLayer =
1336*89c4ff92SAndroid Build Coastguard Worker             network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
1337*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const gatherLayer = network->AddGatherLayer(descriptor, layerName.c_str());
1338*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
1339*89c4ff92SAndroid Build Coastguard Worker 
1340*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(0));
1341*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).Connect(gatherLayer->GetInputSlot(1));
1342*89c4ff92SAndroid Build Coastguard Worker     gatherLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1343*89c4ff92SAndroid Build Coastguard Worker 
1344*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo);
1345*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
1346*89c4ff92SAndroid Build Coastguard Worker     gatherLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1347*89c4ff92SAndroid Build Coastguard Worker 
1348*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1349*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1350*89c4ff92SAndroid Build Coastguard Worker 
1351*89c4ff92SAndroid Build Coastguard Worker     GatherLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo}, descriptor);
1352*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1353*89c4ff92SAndroid Build Coastguard Worker }
1354*89c4ff92SAndroid Build Coastguard Worker 
1355*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeGatherNd")
1356*89c4ff92SAndroid Build Coastguard Worker {
1357*89c4ff92SAndroid Build Coastguard Worker     class GatherNdLayerVerifier : public LayerVerifierBase
1358*89c4ff92SAndroid Build Coastguard Worker     {
1359*89c4ff92SAndroid Build Coastguard Worker     public:
GatherNdLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)1360*89c4ff92SAndroid Build Coastguard Worker         GatherNdLayerVerifier(const std::string& layerName,
1361*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& inputInfos,
1362*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& outputInfos)
1363*89c4ff92SAndroid Build Coastguard Worker                 : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
1364*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor &,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)1365*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
1366*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor&,
1367*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
1368*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
1369*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
1370*89c4ff92SAndroid Build Coastguard Worker         {
1371*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(constants, id);
1372*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
1373*89c4ff92SAndroid Build Coastguard Worker             {
1374*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input:
1375*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output:
1376*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Constant:
1377*89c4ff92SAndroid Build Coastguard Worker                     break;
1378*89c4ff92SAndroid Build Coastguard Worker                 default:
1379*89c4ff92SAndroid Build Coastguard Worker                 {
1380*89c4ff92SAndroid Build Coastguard Worker                     VerifyNameAndConnections(layer, name);
1381*89c4ff92SAndroid Build Coastguard Worker                 }
1382*89c4ff92SAndroid Build Coastguard Worker             }
1383*89c4ff92SAndroid Build Coastguard Worker         }
1384*89c4ff92SAndroid Build Coastguard Worker     };
1385*89c4ff92SAndroid Build Coastguard Worker 
1386*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("gatherNd");
1387*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({ 6, 3 }, armnn::DataType::QAsymmU8);
1388*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({ 3, 3 }, armnn::DataType::QAsymmU8);
1389*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32, 0.0f, 0, true);
1390*89c4ff92SAndroid Build Coastguard Worker 
1391*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationScale(1.0f);
1392*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.SetQuantizationOffset(0);
1393*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationScale(1.0f);
1394*89c4ff92SAndroid Build Coastguard Worker     outputInfo.SetQuantizationOffset(0);
1395*89c4ff92SAndroid Build Coastguard Worker 
1396*89c4ff92SAndroid Build Coastguard Worker     const std::vector<int32_t>& indicesData = {5, 1, 0};
1397*89c4ff92SAndroid Build Coastguard Worker 
1398*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1399*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const inputLayer = network->AddInputLayer(0);
1400*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const constantLayer =
1401*89c4ff92SAndroid Build Coastguard Worker                    network->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
1402*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const gatherNdLayer = network->AddGatherNdLayer(layerName.c_str());
1403*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer *const outputLayer = network->AddOutputLayer(0);
1404*89c4ff92SAndroid Build Coastguard Worker 
1405*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(0));
1406*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).Connect(gatherNdLayer->GetInputSlot(1));
1407*89c4ff92SAndroid Build Coastguard Worker     gatherNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1408*89c4ff92SAndroid Build Coastguard Worker 
1409*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(paramsInfo);
1410*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).SetTensorInfo(indicesInfo);
1411*89c4ff92SAndroid Build Coastguard Worker     gatherNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1412*89c4ff92SAndroid Build Coastguard Worker 
1413*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1414*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1415*89c4ff92SAndroid Build Coastguard Worker 
1416*89c4ff92SAndroid Build Coastguard Worker     GatherNdLayerVerifier verifier(layerName, {paramsInfo, indicesInfo}, {outputInfo});
1417*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1418*89c4ff92SAndroid Build Coastguard Worker }
1419*89c4ff92SAndroid Build Coastguard Worker 
1420*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeComparisonGreater")
1421*89c4ff92SAndroid Build Coastguard Worker {
1422*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("greater");
1423*89c4ff92SAndroid Build Coastguard Worker 
1424*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 4};
1425*89c4ff92SAndroid Build Coastguard Worker 
1426*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
1427*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
1428*89c4ff92SAndroid Build Coastguard Worker 
1429*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1430*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1431*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
1432*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor greaterDescriptor(armnn::ComparisonOperation::Greater);
1433*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const equalLayer = network->AddComparisonLayer(greaterDescriptor, layerName.c_str());
1434*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1435*89c4ff92SAndroid Build Coastguard Worker 
1436*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
1437*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
1438*89c4ff92SAndroid Build Coastguard Worker     equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1439*89c4ff92SAndroid Build Coastguard Worker 
1440*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
1441*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
1442*89c4ff92SAndroid Build Coastguard Worker     equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1443*89c4ff92SAndroid Build Coastguard Worker 
1444*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1445*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1446*89c4ff92SAndroid Build Coastguard Worker 
1447*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, { inputInfo, inputInfo }, { outputInfo });
1448*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1449*89c4ff92SAndroid Build Coastguard Worker }
1450*89c4ff92SAndroid Build Coastguard Worker 
1451*89c4ff92SAndroid Build Coastguard Worker 
1452*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeInstanceNormalization")
1453*89c4ff92SAndroid Build Coastguard Worker {
1454*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("instanceNormalization");
1455*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 1, 5 }, armnn::DataType::Float32);
1456*89c4ff92SAndroid Build Coastguard Worker 
1457*89c4ff92SAndroid Build Coastguard Worker     armnn::InstanceNormalizationDescriptor descriptor;
1458*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Gamma      = 1.1f;
1459*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Beta       = 0.1f;
1460*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Eps        = 0.0001f;
1461*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
1462*89c4ff92SAndroid Build Coastguard Worker 
1463*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1464*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer        = network->AddInputLayer(0);
1465*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const instanceNormLayer =
1466*89c4ff92SAndroid Build Coastguard Worker         network->AddInstanceNormalizationLayer(descriptor, layerName.c_str());
1467*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer       = network->AddOutputLayer(0);
1468*89c4ff92SAndroid Build Coastguard Worker 
1469*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(instanceNormLayer->GetInputSlot(0));
1470*89c4ff92SAndroid Build Coastguard Worker     instanceNormLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1471*89c4ff92SAndroid Build Coastguard Worker 
1472*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
1473*89c4ff92SAndroid Build Coastguard Worker     instanceNormLayer->GetOutputSlot(0).SetTensorInfo(info);
1474*89c4ff92SAndroid Build Coastguard Worker 
1475*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1476*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1477*89c4ff92SAndroid Build Coastguard Worker 
1478*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::InstanceNormalizationDescriptor> verifier(
1479*89c4ff92SAndroid Build Coastguard Worker             layerName, {info}, {info}, descriptor);
1480*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1481*89c4ff92SAndroid Build Coastguard Worker }
1482*89c4ff92SAndroid Build Coastguard Worker 
1483*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeL2Normalization")
1484*89c4ff92SAndroid Build Coastguard Worker {
1485*89c4ff92SAndroid Build Coastguard Worker     const std::string l2NormLayerName("l2Normalization");
1486*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({1, 2, 1, 5}, armnn::DataType::Float32);
1487*89c4ff92SAndroid Build Coastguard Worker 
1488*89c4ff92SAndroid Build Coastguard Worker     armnn::L2NormalizationDescriptor desc;
1489*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
1490*89c4ff92SAndroid Build Coastguard Worker     desc.m_Eps = 0.0001f;
1491*89c4ff92SAndroid Build Coastguard Worker 
1492*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1493*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1494*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const l2NormLayer = network->AddL2NormalizationLayer(desc, l2NormLayerName.c_str());
1495*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1496*89c4ff92SAndroid Build Coastguard Worker 
1497*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(l2NormLayer->GetInputSlot(0));
1498*89c4ff92SAndroid Build Coastguard Worker     l2NormLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1499*89c4ff92SAndroid Build Coastguard Worker 
1500*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
1501*89c4ff92SAndroid Build Coastguard Worker     l2NormLayer->GetOutputSlot(0).SetTensorInfo(info);
1502*89c4ff92SAndroid Build Coastguard Worker 
1503*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1504*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1505*89c4ff92SAndroid Build Coastguard Worker 
1506*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::L2NormalizationDescriptor> verifier(
1507*89c4ff92SAndroid Build Coastguard Worker             l2NormLayerName, {info}, {info}, desc);
1508*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1509*89c4ff92SAndroid Build Coastguard Worker }
1510*89c4ff92SAndroid Build Coastguard Worker 
1511*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("EnsureL2NormalizationBackwardCompatibility")
1512*89c4ff92SAndroid Build Coastguard Worker {
1513*89c4ff92SAndroid Build Coastguard Worker     // The hex data below is a flat buffer containing a simple network with one input
1514*89c4ff92SAndroid Build Coastguard Worker     // a L2Normalization layer and an output layer with dimensions as per the tensor infos below.
1515*89c4ff92SAndroid Build Coastguard Worker     //
1516*89c4ff92SAndroid Build Coastguard Worker     // This test verifies that we can still read back these old style
1517*89c4ff92SAndroid Build Coastguard Worker     // models without the normalization epsilon value.
1518*89c4ff92SAndroid Build Coastguard Worker     const std::vector<uint8_t> l2NormalizationModel =
1519*89c4ff92SAndroid Build Coastguard Worker     {
1520*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x0A, 0x00,
1521*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
1522*89c4ff92SAndroid Build Coastguard Worker         0x3C, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
1523*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xE8, 0xFE, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x0B,
1524*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0xD6, 0xFE, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00,
1525*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9E, 0xFF, 0xFF, 0xFF, 0x02, 0x00, 0x00, 0x00,
1526*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00,
1527*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1528*89c4ff92SAndroid Build Coastguard Worker         0x4C, 0xFF, 0xFF, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x44, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
1529*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x20, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00,
1530*89c4ff92SAndroid Build Coastguard Worker         0x20, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
1531*89c4ff92SAndroid Build Coastguard Worker         0x0E, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0E, 0x00, 0x00, 0x00,
1532*89c4ff92SAndroid Build Coastguard Worker         0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1F, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00, 0x20, 0x00,
1533*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x6C, 0x32, 0x4E, 0x6F, 0x72, 0x6D, 0x61, 0x6C, 0x69, 0x7A, 0x61, 0x74,
1534*89c4ff92SAndroid Build Coastguard Worker         0x69, 0x6F, 0x6E, 0x00, 0x01, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0C, 0x00,
1535*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
1536*89c4ff92SAndroid Build Coastguard Worker         0x52, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
1537*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
1538*89c4ff92SAndroid Build Coastguard Worker         0x08, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1539*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x07, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
1540*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0xF6, 0xFF, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x0A, 0x00,
1541*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, 0x00, 0x14, 0x00, 0x00, 0x00,
1542*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00,
1543*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1544*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0A, 0x00, 0x00, 0x00,
1545*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x08, 0x00,
1546*89c4ff92SAndroid Build Coastguard Worker         0x07, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
1547*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
1548*89c4ff92SAndroid Build Coastguard Worker         0x05, 0x00, 0x00, 0x00, 0x00
1549*89c4ff92SAndroid Build Coastguard Worker     };
1550*89c4ff92SAndroid Build Coastguard Worker 
1551*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork =
1552*89c4ff92SAndroid Build Coastguard Worker         DeserializeNetwork(std::string(l2NormalizationModel.begin(), l2NormalizationModel.end()));
1553*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1554*89c4ff92SAndroid Build Coastguard Worker 
1555*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("l2Normalization");
1556*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo = armnn::TensorInfo({1, 2, 1, 5}, armnn::DataType::Float32, 0.0f, 0);
1557*89c4ff92SAndroid Build Coastguard Worker 
1558*89c4ff92SAndroid Build Coastguard Worker     armnn::L2NormalizationDescriptor desc;
1559*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
1560*89c4ff92SAndroid Build Coastguard Worker     // Since this variable does not exist in the l2NormalizationModel dump, the default value will be loaded
1561*89c4ff92SAndroid Build Coastguard Worker     desc.m_Eps = 1e-12f;
1562*89c4ff92SAndroid Build Coastguard Worker 
1563*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::L2NormalizationDescriptor> verifier(
1564*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {inputInfo}, desc);
1565*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1566*89c4ff92SAndroid Build Coastguard Worker }
1567*89c4ff92SAndroid Build Coastguard Worker 
1568*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeLogicalBinary")
1569*89c4ff92SAndroid Build Coastguard Worker {
1570*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("logicalBinaryAnd");
1571*89c4ff92SAndroid Build Coastguard Worker 
1572*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 2};
1573*89c4ff92SAndroid Build Coastguard Worker 
1574*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Boolean);
1575*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
1576*89c4ff92SAndroid Build Coastguard Worker 
1577*89c4ff92SAndroid Build Coastguard Worker     armnn::LogicalBinaryDescriptor descriptor(armnn::LogicalBinaryOperation::LogicalAnd);
1578*89c4ff92SAndroid Build Coastguard Worker 
1579*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1580*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0        = network->AddInputLayer(0);
1581*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1        = network->AddInputLayer(1);
1582*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const logicalBinaryLayer = network->AddLogicalBinaryLayer(descriptor, layerName.c_str());
1583*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer        = network->AddOutputLayer(0);
1584*89c4ff92SAndroid Build Coastguard Worker 
1585*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(logicalBinaryLayer->GetInputSlot(0));
1586*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(logicalBinaryLayer->GetInputSlot(1));
1587*89c4ff92SAndroid Build Coastguard Worker     logicalBinaryLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1588*89c4ff92SAndroid Build Coastguard Worker 
1589*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
1590*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
1591*89c4ff92SAndroid Build Coastguard Worker     logicalBinaryLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1592*89c4ff92SAndroid Build Coastguard Worker 
1593*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1594*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1595*89c4ff92SAndroid Build Coastguard Worker 
1596*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::LogicalBinaryDescriptor> verifier(
1597*89c4ff92SAndroid Build Coastguard Worker             layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
1598*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1599*89c4ff92SAndroid Build Coastguard Worker }
1600*89c4ff92SAndroid Build Coastguard Worker 
1601*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeLogSoftmax")
1602*89c4ff92SAndroid Build Coastguard Worker {
1603*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("log_softmax");
1604*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({1, 10}, armnn::DataType::Float32);
1605*89c4ff92SAndroid Build Coastguard Worker 
1606*89c4ff92SAndroid Build Coastguard Worker     armnn::LogSoftmaxDescriptor descriptor;
1607*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Beta = 1.0f;
1608*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Axis = -1;
1609*89c4ff92SAndroid Build Coastguard Worker 
1610*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1611*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer      = network->AddInputLayer(0);
1612*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const logSoftmaxLayer = network->AddLogSoftmaxLayer(descriptor, layerName.c_str());
1613*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer     = network->AddOutputLayer(0);
1614*89c4ff92SAndroid Build Coastguard Worker 
1615*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(logSoftmaxLayer->GetInputSlot(0));
1616*89c4ff92SAndroid Build Coastguard Worker     logSoftmaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1617*89c4ff92SAndroid Build Coastguard Worker 
1618*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
1619*89c4ff92SAndroid Build Coastguard Worker     logSoftmaxLayer->GetOutputSlot(0).SetTensorInfo(info);
1620*89c4ff92SAndroid Build Coastguard Worker 
1621*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1622*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1623*89c4ff92SAndroid Build Coastguard Worker 
1624*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::LogSoftmaxDescriptor> verifier(layerName, {info}, {info}, descriptor);
1625*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1626*89c4ff92SAndroid Build Coastguard Worker }
1627*89c4ff92SAndroid Build Coastguard Worker 
1628*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeMaximum")
1629*89c4ff92SAndroid Build Coastguard Worker {
1630*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("maximum");
1631*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
1632*89c4ff92SAndroid Build Coastguard Worker 
1633*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1634*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1635*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
1636*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
1637*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const maximumLayer = network->AddMaximumLayer(layerName.c_str());
1638*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
1639*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1640*89c4ff92SAndroid Build Coastguard Worker 
1641*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(maximumLayer->GetInputSlot(0));
1642*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(maximumLayer->GetInputSlot(1));
1643*89c4ff92SAndroid Build Coastguard Worker     maximumLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1644*89c4ff92SAndroid Build Coastguard Worker 
1645*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
1646*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
1647*89c4ff92SAndroid Build Coastguard Worker     maximumLayer->GetOutputSlot(0).SetTensorInfo(info);
1648*89c4ff92SAndroid Build Coastguard Worker 
1649*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1650*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1651*89c4ff92SAndroid Build Coastguard Worker 
1652*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
1653*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1654*89c4ff92SAndroid Build Coastguard Worker }
1655*89c4ff92SAndroid Build Coastguard Worker 
1656*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeMean")
1657*89c4ff92SAndroid Build Coastguard Worker {
1658*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("mean");
1659*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 1, 3, 2}, armnn::DataType::Float32);
1660*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 1, 1, 2}, armnn::DataType::Float32);
1661*89c4ff92SAndroid Build Coastguard Worker 
1662*89c4ff92SAndroid Build Coastguard Worker     armnn::MeanDescriptor descriptor;
1663*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Axis = { 2 };
1664*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_KeepDims = true;
1665*89c4ff92SAndroid Build Coastguard Worker 
1666*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1667*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
1668*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const meanLayer = network->AddMeanLayer(descriptor, layerName.c_str());
1669*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer  = network->AddOutputLayer(0);
1670*89c4ff92SAndroid Build Coastguard Worker 
1671*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(meanLayer->GetInputSlot(0));
1672*89c4ff92SAndroid Build Coastguard Worker     meanLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1673*89c4ff92SAndroid Build Coastguard Worker 
1674*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
1675*89c4ff92SAndroid Build Coastguard Worker     meanLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1676*89c4ff92SAndroid Build Coastguard Worker 
1677*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1678*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1679*89c4ff92SAndroid Build Coastguard Worker 
1680*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::MeanDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, descriptor);
1681*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1682*89c4ff92SAndroid Build Coastguard Worker }
1683*89c4ff92SAndroid Build Coastguard Worker 
1684*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeMerge")
1685*89c4ff92SAndroid Build Coastguard Worker {
1686*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("merge");
1687*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
1688*89c4ff92SAndroid Build Coastguard Worker 
1689*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1690*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1691*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
1692*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const mergeLayer = network->AddMergeLayer(layerName.c_str());
1693*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1694*89c4ff92SAndroid Build Coastguard Worker 
1695*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0));
1696*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(1));
1697*89c4ff92SAndroid Build Coastguard Worker     mergeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1698*89c4ff92SAndroid Build Coastguard Worker 
1699*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
1700*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
1701*89c4ff92SAndroid Build Coastguard Worker     mergeLayer->GetOutputSlot(0).SetTensorInfo(info);
1702*89c4ff92SAndroid Build Coastguard Worker 
1703*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1704*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1705*89c4ff92SAndroid Build Coastguard Worker 
1706*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
1707*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1708*89c4ff92SAndroid Build Coastguard Worker }
1709*89c4ff92SAndroid Build Coastguard Worker 
1710*89c4ff92SAndroid Build Coastguard Worker class MergerLayerVerifier : public LayerVerifierBaseWithDescriptor<armnn::OriginsDescriptor>
1711*89c4ff92SAndroid Build Coastguard Worker {
1712*89c4ff92SAndroid Build Coastguard Worker public:
MergerLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::OriginsDescriptor & descriptor)1713*89c4ff92SAndroid Build Coastguard Worker     MergerLayerVerifier(const std::string& layerName,
1714*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& inputInfos,
1715*89c4ff92SAndroid Build Coastguard Worker                         const std::vector<armnn::TensorInfo>& outputInfos,
1716*89c4ff92SAndroid Build Coastguard Worker                         const armnn::OriginsDescriptor& descriptor)
1717*89c4ff92SAndroid Build Coastguard Worker         : LayerVerifierBaseWithDescriptor<armnn::OriginsDescriptor>(layerName, inputInfos, outputInfos, descriptor) {}
1718*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)1719*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
1720*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
1721*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
1722*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
1723*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
1724*89c4ff92SAndroid Build Coastguard Worker     {
1725*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(descriptor, constants, id);
1726*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
1727*89c4ff92SAndroid Build Coastguard Worker         {
1728*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
1729*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
1730*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Merge:
1731*89c4ff92SAndroid Build Coastguard Worker             {
1732*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("MergerLayer should have translated to ConcatLayer");
1733*89c4ff92SAndroid Build Coastguard Worker                 break;
1734*89c4ff92SAndroid Build Coastguard Worker             }
1735*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Concat:
1736*89c4ff92SAndroid Build Coastguard Worker             {
1737*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
1738*89c4ff92SAndroid Build Coastguard Worker                 const armnn::MergerDescriptor& layerDescriptor =
1739*89c4ff92SAndroid Build Coastguard Worker                         static_cast<const armnn::MergerDescriptor&>(descriptor);
1740*89c4ff92SAndroid Build Coastguard Worker                 VerifyDescriptor(layerDescriptor);
1741*89c4ff92SAndroid Build Coastguard Worker                 break;
1742*89c4ff92SAndroid Build Coastguard Worker             }
1743*89c4ff92SAndroid Build Coastguard Worker             default:
1744*89c4ff92SAndroid Build Coastguard Worker             {
1745*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("Unexpected layer type in Merge test model");
1746*89c4ff92SAndroid Build Coastguard Worker             }
1747*89c4ff92SAndroid Build Coastguard Worker         }
1748*89c4ff92SAndroid Build Coastguard Worker     }
1749*89c4ff92SAndroid Build Coastguard Worker };
1750*89c4ff92SAndroid Build Coastguard Worker 
1751*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("EnsureMergerLayerBackwardCompatibility")
1752*89c4ff92SAndroid Build Coastguard Worker {
1753*89c4ff92SAndroid Build Coastguard Worker     // The hex data below is a flat buffer containing a simple network with two inputs
1754*89c4ff92SAndroid Build Coastguard Worker     // a merger layer (now deprecated) and an output layer with dimensions as per the tensor infos below.
1755*89c4ff92SAndroid Build Coastguard Worker     //
1756*89c4ff92SAndroid Build Coastguard Worker     // This test verifies that we can still read back these old style
1757*89c4ff92SAndroid Build Coastguard Worker     // models replacing the MergerLayers with ConcatLayers with the same parameters.
1758*89c4ff92SAndroid Build Coastguard Worker     const std::vector<uint8_t> mergerModel =
1759*89c4ff92SAndroid Build Coastguard Worker     {
1760*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x0A, 0x00,
1761*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
1762*89c4ff92SAndroid Build Coastguard Worker         0x38, 0x02, 0x00, 0x00, 0x8C, 0x01, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x02, 0x00,
1763*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
1764*89c4ff92SAndroid Build Coastguard Worker         0xF4, 0xFD, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x0B, 0x04, 0x00, 0x00, 0x00, 0x92, 0xFE, 0xFF, 0xFF, 0x04, 0x00,
1765*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x9A, 0xFE, 0xFF, 0xFF, 0x04, 0x00, 0x00, 0x00, 0x7E, 0xFE, 0xFF, 0xFF, 0x03, 0x00, 0x00, 0x00,
1766*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00,
1767*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1768*89c4ff92SAndroid Build Coastguard Worker         0xF8, 0xFE, 0xFF, 0xFF, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x48, 0xFE, 0xFF, 0xFF, 0x00, 0x00,
1769*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x1F, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00,
1770*89c4ff92SAndroid Build Coastguard Worker         0x68, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00,
1771*89c4ff92SAndroid Build Coastguard Worker         0x0C, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
1772*89c4ff92SAndroid Build Coastguard Worker         0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x22, 0xFF, 0xFF, 0xFF, 0x04, 0x00,
1773*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1774*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x3E, 0xFF, 0xFF, 0xFF, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00,
1775*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x36, 0xFF, 0xFF, 0xFF,
1776*89c4ff92SAndroid Build Coastguard Worker         0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1E, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1C, 0x00,
1777*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x6D, 0x65, 0x72, 0x67, 0x65, 0x72, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
1778*89c4ff92SAndroid Build Coastguard Worker         0x5C, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x34, 0xFF,
1779*89c4ff92SAndroid Build Coastguard Worker         0xFF, 0xFF, 0x04, 0x00, 0x00, 0x00, 0x92, 0xFE, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00,
1780*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00,
1781*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00,
1782*89c4ff92SAndroid Build Coastguard Worker         0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x00, 0x00,
1783*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0E, 0x00,
1784*89c4ff92SAndroid Build Coastguard Worker         0x07, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00,
1785*89c4ff92SAndroid Build Coastguard Worker         0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0E, 0x00,
1786*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
1787*89c4ff92SAndroid Build Coastguard Worker         0x0E, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0E, 0x00, 0x00, 0x00,
1788*89c4ff92SAndroid Build Coastguard Worker         0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00,
1789*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
1790*89c4ff92SAndroid Build Coastguard Worker         0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00,
1791*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x66, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1792*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00,
1793*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x07, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09,
1794*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0xF6, 0xFF, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x0A, 0x00,
1795*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, 0x00, 0x14, 0x00, 0x00, 0x00,
1796*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00,
1797*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
1798*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0A, 0x00, 0x00, 0x00,
1799*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x08, 0x00,
1800*89c4ff92SAndroid Build Coastguard Worker         0x07, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
1801*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
1802*89c4ff92SAndroid Build Coastguard Worker         0x02, 0x00, 0x00, 0x00
1803*89c4ff92SAndroid Build Coastguard Worker     };
1804*89c4ff92SAndroid Build Coastguard Worker 
1805*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(std::string(mergerModel.begin(), mergerModel.end()));
1806*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1807*89c4ff92SAndroid Build Coastguard Worker 
1808*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo({ 2, 3, 2, 2 }, armnn::DataType::Float32, 0.0f, 0);
1809*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({ 4, 3, 2, 2 }, armnn::DataType::Float32, 0.0f, 0);
1810*89c4ff92SAndroid Build Coastguard Worker 
1811*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::TensorShape> shapes({inputInfo.GetShape(), inputInfo.GetShape()});
1812*89c4ff92SAndroid Build Coastguard Worker 
1813*89c4ff92SAndroid Build Coastguard Worker     armnn::OriginsDescriptor descriptor =
1814*89c4ff92SAndroid Build Coastguard Worker             armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), 0);
1815*89c4ff92SAndroid Build Coastguard Worker 
1816*89c4ff92SAndroid Build Coastguard Worker     MergerLayerVerifier verifier("merger", { inputInfo, inputInfo }, { outputInfo }, descriptor);
1817*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1818*89c4ff92SAndroid Build Coastguard Worker }
1819*89c4ff92SAndroid Build Coastguard Worker 
1820*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeConcat")
1821*89c4ff92SAndroid Build Coastguard Worker {
1822*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("concat");
1823*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo = armnn::TensorInfo({2, 3, 2, 2}, armnn::DataType::Float32);
1824*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({4, 3, 2, 2}, armnn::DataType::Float32);
1825*89c4ff92SAndroid Build Coastguard Worker 
1826*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::TensorShape> shapes({inputInfo.GetShape(), inputInfo.GetShape()});
1827*89c4ff92SAndroid Build Coastguard Worker 
1828*89c4ff92SAndroid Build Coastguard Worker     armnn::OriginsDescriptor descriptor =
1829*89c4ff92SAndroid Build Coastguard Worker         armnn::CreateDescriptorForConcatenation(shapes.begin(), shapes.end(), 0);
1830*89c4ff92SAndroid Build Coastguard Worker 
1831*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1832*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayerOne = network->AddInputLayer(0);
1833*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayerTwo = network->AddInputLayer(1);
1834*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const concatLayer = network->AddConcatLayer(descriptor, layerName.c_str());
1835*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1836*89c4ff92SAndroid Build Coastguard Worker 
1837*89c4ff92SAndroid Build Coastguard Worker     inputLayerOne->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(0));
1838*89c4ff92SAndroid Build Coastguard Worker     inputLayerTwo->GetOutputSlot(0).Connect(concatLayer->GetInputSlot(1));
1839*89c4ff92SAndroid Build Coastguard Worker     concatLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1840*89c4ff92SAndroid Build Coastguard Worker 
1841*89c4ff92SAndroid Build Coastguard Worker     inputLayerOne->GetOutputSlot(0).SetTensorInfo(inputInfo);
1842*89c4ff92SAndroid Build Coastguard Worker     inputLayerTwo->GetOutputSlot(0).SetTensorInfo(inputInfo);
1843*89c4ff92SAndroid Build Coastguard Worker     concatLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
1844*89c4ff92SAndroid Build Coastguard Worker 
1845*89c4ff92SAndroid Build Coastguard Worker     std::string concatLayerNetwork = SerializeNetwork(*network);
1846*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(concatLayerNetwork);
1847*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1848*89c4ff92SAndroid Build Coastguard Worker 
1849*89c4ff92SAndroid Build Coastguard Worker     // NOTE: using the MergerLayerVerifier to ensure that it is a concat layer and not a
1850*89c4ff92SAndroid Build Coastguard Worker     //       merger layer that gets placed into the graph.
1851*89c4ff92SAndroid Build Coastguard Worker     MergerLayerVerifier verifier(layerName, {inputInfo, inputInfo}, {outputInfo}, descriptor);
1852*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1853*89c4ff92SAndroid Build Coastguard Worker }
1854*89c4ff92SAndroid Build Coastguard Worker 
1855*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeMinimum")
1856*89c4ff92SAndroid Build Coastguard Worker {
1857*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("minimum");
1858*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
1859*89c4ff92SAndroid Build Coastguard Worker 
1860*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1861*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1862*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
1863*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
1864*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const minimumLayer = network->AddMinimumLayer(layerName.c_str());
1865*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
1866*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1867*89c4ff92SAndroid Build Coastguard Worker 
1868*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(minimumLayer->GetInputSlot(0));
1869*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(minimumLayer->GetInputSlot(1));
1870*89c4ff92SAndroid Build Coastguard Worker     minimumLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1871*89c4ff92SAndroid Build Coastguard Worker 
1872*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
1873*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
1874*89c4ff92SAndroid Build Coastguard Worker     minimumLayer->GetOutputSlot(0).SetTensorInfo(info);
1875*89c4ff92SAndroid Build Coastguard Worker 
1876*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1877*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1878*89c4ff92SAndroid Build Coastguard Worker 
1879*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
1880*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1881*89c4ff92SAndroid Build Coastguard Worker }
1882*89c4ff92SAndroid Build Coastguard Worker 
1883*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeMultiplication")
1884*89c4ff92SAndroid Build Coastguard Worker {
1885*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("multiplication");
1886*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 5, 2, 3 }, armnn::DataType::Float32);
1887*89c4ff92SAndroid Build Coastguard Worker 
1888*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1889*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
1890*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
1891*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
1892*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const multiplicationLayer = network->AddMultiplicationLayer(layerName.c_str());
1893*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
1894*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1895*89c4ff92SAndroid Build Coastguard Worker 
1896*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(multiplicationLayer->GetInputSlot(0));
1897*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(multiplicationLayer->GetInputSlot(1));
1898*89c4ff92SAndroid Build Coastguard Worker     multiplicationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1899*89c4ff92SAndroid Build Coastguard Worker 
1900*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
1901*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
1902*89c4ff92SAndroid Build Coastguard Worker     multiplicationLayer->GetOutputSlot(0).SetTensorInfo(info);
1903*89c4ff92SAndroid Build Coastguard Worker 
1904*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1905*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1906*89c4ff92SAndroid Build Coastguard Worker 
1907*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
1908*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1909*89c4ff92SAndroid Build Coastguard Worker }
1910*89c4ff92SAndroid Build Coastguard Worker 
1911*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePrelu")
1912*89c4ff92SAndroid Build Coastguard Worker {
1913*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("prelu");
1914*89c4ff92SAndroid Build Coastguard Worker 
1915*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo ({ 4, 1, 2 }, armnn::DataType::Float32);
1916*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo alphaTensorInfo ({ 5, 4, 3, 1 }, armnn::DataType::Float32);
1917*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo({ 5, 4, 3, 2 }, armnn::DataType::Float32);
1918*89c4ff92SAndroid Build Coastguard Worker 
1919*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1920*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1921*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const alphaLayer = network->AddInputLayer(1);
1922*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const preluLayer = network->AddPreluLayer(layerName.c_str());
1923*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1924*89c4ff92SAndroid Build Coastguard Worker 
1925*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(0));
1926*89c4ff92SAndroid Build Coastguard Worker     alphaLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
1927*89c4ff92SAndroid Build Coastguard Worker     preluLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1928*89c4ff92SAndroid Build Coastguard Worker 
1929*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
1930*89c4ff92SAndroid Build Coastguard Worker     alphaLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
1931*89c4ff92SAndroid Build Coastguard Worker     preluLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1932*89c4ff92SAndroid Build Coastguard Worker 
1933*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1934*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1935*89c4ff92SAndroid Build Coastguard Worker 
1936*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {inputTensorInfo, alphaTensorInfo}, {outputTensorInfo});
1937*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1938*89c4ff92SAndroid Build Coastguard Worker }
1939*89c4ff92SAndroid Build Coastguard Worker 
1940*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeNormalization")
1941*89c4ff92SAndroid Build Coastguard Worker {
1942*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("normalization");
1943*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({2, 1, 2, 2}, armnn::DataType::Float32);
1944*89c4ff92SAndroid Build Coastguard Worker 
1945*89c4ff92SAndroid Build Coastguard Worker     armnn::NormalizationDescriptor desc;
1946*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
1947*89c4ff92SAndroid Build Coastguard Worker     desc.m_NormSize = 3;
1948*89c4ff92SAndroid Build Coastguard Worker     desc.m_Alpha = 1;
1949*89c4ff92SAndroid Build Coastguard Worker     desc.m_Beta = 1;
1950*89c4ff92SAndroid Build Coastguard Worker     desc.m_K = 1;
1951*89c4ff92SAndroid Build Coastguard Worker 
1952*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1953*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1954*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const normalizationLayer = network->AddNormalizationLayer(desc, layerName.c_str());
1955*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1956*89c4ff92SAndroid Build Coastguard Worker 
1957*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(normalizationLayer->GetInputSlot(0));
1958*89c4ff92SAndroid Build Coastguard Worker     normalizationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1959*89c4ff92SAndroid Build Coastguard Worker 
1960*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
1961*89c4ff92SAndroid Build Coastguard Worker     normalizationLayer->GetOutputSlot(0).SetTensorInfo(info);
1962*89c4ff92SAndroid Build Coastguard Worker 
1963*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1964*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1965*89c4ff92SAndroid Build Coastguard Worker 
1966*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::NormalizationDescriptor> verifier(layerName, {info}, {info}, desc);
1967*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1968*89c4ff92SAndroid Build Coastguard Worker }
1969*89c4ff92SAndroid Build Coastguard Worker 
1970*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePad")
1971*89c4ff92SAndroid Build Coastguard Worker {
1972*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("pad");
1973*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo = armnn::TensorInfo({1, 2, 3, 4}, armnn::DataType::Float32);
1974*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputTensorInfo = armnn::TensorInfo({1, 3, 5, 7}, armnn::DataType::Float32);
1975*89c4ff92SAndroid Build Coastguard Worker 
1976*89c4ff92SAndroid Build Coastguard Worker     armnn::PadDescriptor desc({{0, 0}, {1, 0}, {1, 1}, {1, 2}});
1977*89c4ff92SAndroid Build Coastguard Worker 
1978*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
1979*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
1980*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const padLayer = network->AddPadLayer(desc, layerName.c_str());
1981*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
1982*89c4ff92SAndroid Build Coastguard Worker 
1983*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(padLayer->GetInputSlot(0));
1984*89c4ff92SAndroid Build Coastguard Worker     padLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
1985*89c4ff92SAndroid Build Coastguard Worker 
1986*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
1987*89c4ff92SAndroid Build Coastguard Worker     padLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1988*89c4ff92SAndroid Build Coastguard Worker 
1989*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
1990*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
1991*89c4ff92SAndroid Build Coastguard Worker 
1992*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::PadDescriptor> verifier(layerName,
1993*89c4ff92SAndroid Build Coastguard Worker                                                                    {inputTensorInfo},
1994*89c4ff92SAndroid Build Coastguard Worker                                                                    {outputTensorInfo},
1995*89c4ff92SAndroid Build Coastguard Worker                                                                    desc);
1996*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
1997*89c4ff92SAndroid Build Coastguard Worker }
1998*89c4ff92SAndroid Build Coastguard Worker 
1999*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePadReflect")
2000*89c4ff92SAndroid Build Coastguard Worker {
2001*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("padReflect");
2002*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo = armnn::TensorInfo({1, 2, 3, 4}, armnn::DataType::Float32);
2003*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputTensorInfo = armnn::TensorInfo({1, 3, 5, 7}, armnn::DataType::Float32);
2004*89c4ff92SAndroid Build Coastguard Worker 
2005*89c4ff92SAndroid Build Coastguard Worker     armnn::PadDescriptor desc({{0, 0}, {1, 0}, {1, 1}, {1, 2}});
2006*89c4ff92SAndroid Build Coastguard Worker     desc.m_PaddingMode = armnn::PaddingMode::Reflect;
2007*89c4ff92SAndroid Build Coastguard Worker 
2008*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2009*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2010*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const padLayer = network->AddPadLayer(desc, layerName.c_str());
2011*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2012*89c4ff92SAndroid Build Coastguard Worker 
2013*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(padLayer->GetInputSlot(0));
2014*89c4ff92SAndroid Build Coastguard Worker     padLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2015*89c4ff92SAndroid Build Coastguard Worker 
2016*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2017*89c4ff92SAndroid Build Coastguard Worker     padLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2018*89c4ff92SAndroid Build Coastguard Worker 
2019*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2020*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2021*89c4ff92SAndroid Build Coastguard Worker 
2022*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::PadDescriptor> verifier(layerName,
2023*89c4ff92SAndroid Build Coastguard Worker                                                                    {inputTensorInfo},
2024*89c4ff92SAndroid Build Coastguard Worker                                                                    {outputTensorInfo},
2025*89c4ff92SAndroid Build Coastguard Worker                                                                    desc);
2026*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2027*89c4ff92SAndroid Build Coastguard Worker }
2028*89c4ff92SAndroid Build Coastguard Worker 
TEST_CASE("EnsurePadBackwardCompatibility")
{
    // The PadDescriptor was extended with a float PadValue (so that a value other than 0
    // can be used to pad the tensor).
    //
    // This test holds a binary dump of a simple input -> pad -> output network that was
    // serialized before that change. Deserializing it checks that the descriptor was
    // extended in a backward-compatible way: fields absent from the old binary must come
    // back with the descriptor's default values.
    //
    // NOTE: the bytes below are a captured serializer output; do not edit them by hand.
    const std::vector<uint8_t> padModel =
    {
        0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x0A, 0x00,
        0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
        0x54, 0x01, 0x00, 0x00, 0x6C, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xD0, 0xFE, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x0B,
        0x04, 0x00, 0x00, 0x00, 0x96, 0xFF, 0xFF, 0xFF, 0x04, 0x00, 0x00, 0x00, 0x9E, 0xFF, 0xFF, 0xFF, 0x04, 0x00,
        0x00, 0x00, 0x72, 0xFF, 0xFF, 0xFF, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
        0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00,
        0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2C, 0xFF, 0xFF, 0xFF, 0x01, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x24, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x16, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00,
        0x0E, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x4C, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x0E, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0E, 0x00, 0x00, 0x00,
        0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00,
        0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x70, 0x61, 0x64, 0x00, 0x01, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00,
        0x01, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00,
        0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x05, 0x00,
        0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x07, 0x00, 0x08, 0x00, 0x08, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x04, 0x00, 0x00, 0x00, 0xF6, 0xFF, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x06, 0x00, 0x0A, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x0E, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x00,
        0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00,
        0x08, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x0A, 0x00, 0x10, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
        0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00,
        0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00
    };

    // The deserializer consumes the raw bytes as a std::string.
    armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(std::string(padModel.begin(), padModel.end()));
    CHECK(deserializedNetwork);

    // Expected tensor infos of the pad layer stored in the old dump.
    const armnn::TensorInfo inputInfo  = armnn::TensorInfo({ 1, 2, 3, 4 }, armnn::DataType::Float32, 0.0f, 0);
    const armnn::TensorInfo outputInfo = armnn::TensorInfo({ 1, 3, 5, 7 }, armnn::DataType::Float32, 0.0f, 0);

    // Only the padding list is set explicitly; every other descriptor field is left at its
    // default, which is what the old binary is expected to deserialize to.
    armnn::PadDescriptor descriptor({{ 0, 0 }, { 1, 0 }, { 1, 1 }, { 1, 2 }});

    // "pad" is the layer name embedded in the binary above (bytes 0x70 0x61 0x64).
    LayerVerifierBaseWithDescriptor<armnn::PadDescriptor> verifier("pad", { inputInfo }, { outputInfo }, descriptor);
    deserializedNetwork->ExecuteStrategy(verifier);
}
2082*89c4ff92SAndroid Build Coastguard Worker 
2083*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePermute")
2084*89c4ff92SAndroid Build Coastguard Worker {
2085*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("permute");
2086*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo({4, 3, 2, 1}, armnn::DataType::Float32);
2087*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputTensorInfo({1, 2, 3, 4}, armnn::DataType::Float32);
2088*89c4ff92SAndroid Build Coastguard Worker 
2089*89c4ff92SAndroid Build Coastguard Worker     armnn::PermuteDescriptor descriptor(armnn::PermutationVector({3, 2, 1, 0}));
2090*89c4ff92SAndroid Build Coastguard Worker 
2091*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2092*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2093*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const permuteLayer = network->AddPermuteLayer(descriptor, layerName.c_str());
2094*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2095*89c4ff92SAndroid Build Coastguard Worker 
2096*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(permuteLayer->GetInputSlot(0));
2097*89c4ff92SAndroid Build Coastguard Worker     permuteLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2098*89c4ff92SAndroid Build Coastguard Worker 
2099*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2100*89c4ff92SAndroid Build Coastguard Worker     permuteLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2101*89c4ff92SAndroid Build Coastguard Worker 
2102*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2103*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2104*89c4ff92SAndroid Build Coastguard Worker 
2105*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::PermuteDescriptor> verifier(
2106*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputTensorInfo}, {outputTensorInfo}, descriptor);
2107*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2108*89c4ff92SAndroid Build Coastguard Worker }
2109*89c4ff92SAndroid Build Coastguard Worker 
2110*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePooling2d")
2111*89c4ff92SAndroid Build Coastguard Worker {
2112*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("pooling2d");
2113*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 2, 2, 1}, armnn::DataType::Float32);
2114*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 1, 1, 1}, armnn::DataType::Float32);
2115*89c4ff92SAndroid Build Coastguard Worker 
2116*89c4ff92SAndroid Build Coastguard Worker     armnn::Pooling2dDescriptor desc;
2117*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout          = armnn::DataLayout::NHWC;
2118*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadTop              = 0;
2119*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadBottom           = 0;
2120*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadLeft             = 0;
2121*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadRight            = 0;
2122*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolType            = armnn::PoolingAlgorithm::Average;
2123*89c4ff92SAndroid Build Coastguard Worker     desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2124*89c4ff92SAndroid Build Coastguard Worker     desc.m_PaddingMethod       = armnn::PaddingMethod::Exclude;
2125*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolHeight          = 2;
2126*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth           = 2;
2127*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX             = 2;
2128*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY             = 2;
2129*89c4ff92SAndroid Build Coastguard Worker 
2130*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2131*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2132*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const pooling2dLayer = network->AddPooling2dLayer(desc, layerName.c_str());
2133*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2134*89c4ff92SAndroid Build Coastguard Worker 
2135*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(pooling2dLayer->GetInputSlot(0));
2136*89c4ff92SAndroid Build Coastguard Worker     pooling2dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2137*89c4ff92SAndroid Build Coastguard Worker 
2138*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2139*89c4ff92SAndroid Build Coastguard Worker     pooling2dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2140*89c4ff92SAndroid Build Coastguard Worker 
2141*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2142*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2143*89c4ff92SAndroid Build Coastguard Worker 
2144*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::Pooling2dDescriptor> verifier(
2145*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, desc);
2146*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2147*89c4ff92SAndroid Build Coastguard Worker }
2148*89c4ff92SAndroid Build Coastguard Worker 
2149*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializePooling3d")
2150*89c4ff92SAndroid Build Coastguard Worker {
2151*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("pooling3d");
2152*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 1, 2, 2, 2}, armnn::DataType::Float32);
2153*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 1, 1, 1, 1}, armnn::DataType::Float32);
2154*89c4ff92SAndroid Build Coastguard Worker 
2155*89c4ff92SAndroid Build Coastguard Worker     armnn::Pooling3dDescriptor desc;
2156*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout          = armnn::DataLayout::NDHWC;
2157*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadFront            = 0;
2158*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadBack             = 0;
2159*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadTop              = 0;
2160*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadBottom           = 0;
2161*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadLeft             = 0;
2162*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadRight            = 0;
2163*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolType            = armnn::PoolingAlgorithm::Average;
2164*89c4ff92SAndroid Build Coastguard Worker     desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
2165*89c4ff92SAndroid Build Coastguard Worker     desc.m_PaddingMethod       = armnn::PaddingMethod::Exclude;
2166*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolHeight          = 2;
2167*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth           = 2;
2168*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolDepth           = 2;
2169*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX             = 2;
2170*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY             = 2;
2171*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideZ             = 2;
2172*89c4ff92SAndroid Build Coastguard Worker 
2173*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2174*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2175*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const pooling3dLayer = network->AddPooling3dLayer(desc, layerName.c_str());
2176*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2177*89c4ff92SAndroid Build Coastguard Worker 
2178*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(pooling3dLayer->GetInputSlot(0));
2179*89c4ff92SAndroid Build Coastguard Worker     pooling3dLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2180*89c4ff92SAndroid Build Coastguard Worker 
2181*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2182*89c4ff92SAndroid Build Coastguard Worker     pooling3dLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2183*89c4ff92SAndroid Build Coastguard Worker 
2184*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2185*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2186*89c4ff92SAndroid Build Coastguard Worker 
2187*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::Pooling3dDescriptor> verifier(
2188*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, desc);
2189*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2190*89c4ff92SAndroid Build Coastguard Worker }
2191*89c4ff92SAndroid Build Coastguard Worker 
2192*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeQuantize")
2193*89c4ff92SAndroid Build Coastguard Worker {
2194*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("quantize");
2195*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
2196*89c4ff92SAndroid Build Coastguard Worker 
2197*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2198*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2199*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const quantizeLayer = network->AddQuantizeLayer(layerName.c_str());
2200*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2201*89c4ff92SAndroid Build Coastguard Worker 
2202*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
2203*89c4ff92SAndroid Build Coastguard Worker     quantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2204*89c4ff92SAndroid Build Coastguard Worker 
2205*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
2206*89c4ff92SAndroid Build Coastguard Worker     quantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
2207*89c4ff92SAndroid Build Coastguard Worker 
2208*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2209*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2210*89c4ff92SAndroid Build Coastguard Worker 
2211*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info}, {info});
2212*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2213*89c4ff92SAndroid Build Coastguard Worker }
2214*89c4ff92SAndroid Build Coastguard Worker 
2215*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeRank")
2216*89c4ff92SAndroid Build Coastguard Worker {
2217*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("rank");
2218*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 9}, armnn::DataType::Float32);
2219*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1}, armnn::DataType::Signed32);
2220*89c4ff92SAndroid Build Coastguard Worker 
2221*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2222*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2223*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const rankLayer = network->AddRankLayer(layerName.c_str());
2224*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2225*89c4ff92SAndroid Build Coastguard Worker 
2226*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(rankLayer->GetInputSlot(0));
2227*89c4ff92SAndroid Build Coastguard Worker     rankLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2228*89c4ff92SAndroid Build Coastguard Worker 
2229*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2230*89c4ff92SAndroid Build Coastguard Worker     rankLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2231*89c4ff92SAndroid Build Coastguard Worker 
2232*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2233*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2234*89c4ff92SAndroid Build Coastguard Worker 
2235*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {inputInfo}, {outputInfo});
2236*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2237*89c4ff92SAndroid Build Coastguard Worker }
2238*89c4ff92SAndroid Build Coastguard Worker 
2239*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeReduceSum")
2240*89c4ff92SAndroid Build Coastguard Worker {
2241*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("Reduce_Sum");
2242*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 1, 3, 2}, armnn::DataType::Float32);
2243*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({1, 1, 1, 2}, armnn::DataType::Float32);
2244*89c4ff92SAndroid Build Coastguard Worker 
2245*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor descriptor;
2246*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_vAxis = { 2 };
2247*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ReduceOperation = armnn::ReduceOperation::Sum;
2248*89c4ff92SAndroid Build Coastguard Worker 
2249*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2250*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
2251*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const reduceSumLayer = network->AddReduceLayer(descriptor, layerName.c_str());
2252*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer  = network->AddOutputLayer(0);
2253*89c4ff92SAndroid Build Coastguard Worker 
2254*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(reduceSumLayer->GetInputSlot(0));
2255*89c4ff92SAndroid Build Coastguard Worker     reduceSumLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2256*89c4ff92SAndroid Build Coastguard Worker 
2257*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2258*89c4ff92SAndroid Build Coastguard Worker     reduceSumLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2259*89c4ff92SAndroid Build Coastguard Worker 
2260*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2261*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2262*89c4ff92SAndroid Build Coastguard Worker 
2263*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ReduceDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, descriptor);
2264*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2265*89c4ff92SAndroid Build Coastguard Worker }
2266*89c4ff92SAndroid Build Coastguard Worker 
2267*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeReshape")
2268*89c4ff92SAndroid Build Coastguard Worker {
2269*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("reshape");
2270*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 9}, armnn::DataType::Float32);
2271*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({3, 3}, armnn::DataType::Float32);
2272*89c4ff92SAndroid Build Coastguard Worker 
2273*89c4ff92SAndroid Build Coastguard Worker     armnn::ReshapeDescriptor descriptor({3, 3});
2274*89c4ff92SAndroid Build Coastguard Worker 
2275*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2276*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2277*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const reshapeLayer = network->AddReshapeLayer(descriptor, layerName.c_str());
2278*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2279*89c4ff92SAndroid Build Coastguard Worker 
2280*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(reshapeLayer->GetInputSlot(0));
2281*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2282*89c4ff92SAndroid Build Coastguard Worker 
2283*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2284*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2285*89c4ff92SAndroid Build Coastguard Worker 
2286*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2287*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2288*89c4ff92SAndroid Build Coastguard Worker 
2289*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ReshapeDescriptor> verifier(
2290*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, descriptor);
2291*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2292*89c4ff92SAndroid Build Coastguard Worker }
2293*89c4ff92SAndroid Build Coastguard Worker 
2294*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeResize")
2295*89c4ff92SAndroid Build Coastguard Worker {
2296*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("resize");
2297*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo({1, 3, 5, 5}, armnn::DataType::Float32);
2298*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({1, 3, 2, 4}, armnn::DataType::Float32);
2299*89c4ff92SAndroid Build Coastguard Worker 
2300*89c4ff92SAndroid Build Coastguard Worker     armnn::ResizeDescriptor desc;
2301*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetWidth  = 4;
2302*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetHeight = 2;
2303*89c4ff92SAndroid Build Coastguard Worker     desc.m_Method       = armnn::ResizeMethod::NearestNeighbor;
2304*89c4ff92SAndroid Build Coastguard Worker     desc.m_AlignCorners = true;
2305*89c4ff92SAndroid Build Coastguard Worker     desc.m_HalfPixelCenters = true;
2306*89c4ff92SAndroid Build Coastguard Worker 
2307*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2308*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2309*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const resizeLayer = network->AddResizeLayer(desc, layerName.c_str());
2310*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2311*89c4ff92SAndroid Build Coastguard Worker 
2312*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(resizeLayer->GetInputSlot(0));
2313*89c4ff92SAndroid Build Coastguard Worker     resizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2314*89c4ff92SAndroid Build Coastguard Worker 
2315*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2316*89c4ff92SAndroid Build Coastguard Worker     resizeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2317*89c4ff92SAndroid Build Coastguard Worker 
2318*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2319*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2320*89c4ff92SAndroid Build Coastguard Worker 
2321*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ResizeDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, desc);
2322*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2323*89c4ff92SAndroid Build Coastguard Worker }
2324*89c4ff92SAndroid Build Coastguard Worker 
2325*89c4ff92SAndroid Build Coastguard Worker class ResizeBilinearLayerVerifier : public LayerVerifierBaseWithDescriptor<armnn::ResizeDescriptor>
2326*89c4ff92SAndroid Build Coastguard Worker {
2327*89c4ff92SAndroid Build Coastguard Worker public:
ResizeBilinearLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ResizeDescriptor & descriptor)2328*89c4ff92SAndroid Build Coastguard Worker     ResizeBilinearLayerVerifier(const std::string& layerName,
2329*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<armnn::TensorInfo>& inputInfos,
2330*89c4ff92SAndroid Build Coastguard Worker                                 const std::vector<armnn::TensorInfo>& outputInfos,
2331*89c4ff92SAndroid Build Coastguard Worker                                 const armnn::ResizeDescriptor& descriptor)
2332*89c4ff92SAndroid Build Coastguard Worker         : LayerVerifierBaseWithDescriptor<armnn::ResizeDescriptor>(
2333*89c4ff92SAndroid Build Coastguard Worker             layerName, inputInfos, outputInfos, descriptor) {}
2334*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)2335*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
2336*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
2337*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
2338*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
2339*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
2340*89c4ff92SAndroid Build Coastguard Worker     {
2341*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(descriptor, constants, id);
2342*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
2343*89c4ff92SAndroid Build Coastguard Worker         {
2344*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
2345*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
2346*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Resize:
2347*89c4ff92SAndroid Build Coastguard Worker             {
2348*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
2349*89c4ff92SAndroid Build Coastguard Worker                 const armnn::ResizeDescriptor& layerDescriptor =
2350*89c4ff92SAndroid Build Coastguard Worker                         static_cast<const armnn::ResizeDescriptor&>(descriptor);
2351*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_Method             == armnn::ResizeMethod::Bilinear);
2352*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_TargetWidth        == m_Descriptor.m_TargetWidth);
2353*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_TargetHeight       == m_Descriptor.m_TargetHeight);
2354*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_DataLayout         == m_Descriptor.m_DataLayout);
2355*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_AlignCorners       == m_Descriptor.m_AlignCorners);
2356*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_HalfPixelCenters   == m_Descriptor.m_HalfPixelCenters);
2357*89c4ff92SAndroid Build Coastguard Worker                 break;
2358*89c4ff92SAndroid Build Coastguard Worker             }
2359*89c4ff92SAndroid Build Coastguard Worker             default:
2360*89c4ff92SAndroid Build Coastguard Worker             {
2361*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("Unexpected layer type in test model. ResizeBiliniar "
2362*89c4ff92SAndroid Build Coastguard Worker                                        "should have translated to Resize");
2363*89c4ff92SAndroid Build Coastguard Worker             }
2364*89c4ff92SAndroid Build Coastguard Worker         }
2365*89c4ff92SAndroid Build Coastguard Worker     }
2366*89c4ff92SAndroid Build Coastguard Worker };
2367*89c4ff92SAndroid Build Coastguard Worker 
2368*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeResizeBilinear")
2369*89c4ff92SAndroid Build Coastguard Worker {
2370*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("resizeBilinear");
2371*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo({1, 3, 5, 5}, armnn::DataType::Float32);
2372*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({1, 3, 2, 4}, armnn::DataType::Float32);
2373*89c4ff92SAndroid Build Coastguard Worker 
2374*89c4ff92SAndroid Build Coastguard Worker     armnn::ResizeDescriptor desc;
2375*89c4ff92SAndroid Build Coastguard Worker     desc.m_Method = armnn::ResizeMethod::Bilinear;
2376*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetWidth  = 4u;
2377*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetHeight = 2u;
2378*89c4ff92SAndroid Build Coastguard Worker     desc.m_AlignCorners = true;
2379*89c4ff92SAndroid Build Coastguard Worker     desc.m_HalfPixelCenters = true;
2380*89c4ff92SAndroid Build Coastguard Worker 
2381*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2382*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2383*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const resizeLayer = network->AddResizeLayer(desc, layerName.c_str());
2384*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2385*89c4ff92SAndroid Build Coastguard Worker 
2386*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(resizeLayer->GetInputSlot(0));
2387*89c4ff92SAndroid Build Coastguard Worker     resizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2388*89c4ff92SAndroid Build Coastguard Worker 
2389*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2390*89c4ff92SAndroid Build Coastguard Worker     resizeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2391*89c4ff92SAndroid Build Coastguard Worker 
2392*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2393*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2394*89c4ff92SAndroid Build Coastguard Worker 
2395*89c4ff92SAndroid Build Coastguard Worker     ResizeBilinearLayerVerifier verifier(layerName, {inputInfo}, {outputInfo}, desc);
2396*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2397*89c4ff92SAndroid Build Coastguard Worker }
2398*89c4ff92SAndroid Build Coastguard Worker 
2399*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("EnsureResizeBilinearBackwardCompatibility")
2400*89c4ff92SAndroid Build Coastguard Worker {
2401*89c4ff92SAndroid Build Coastguard Worker     // The hex data below is a flat buffer containing a simple network with an input,
2402*89c4ff92SAndroid Build Coastguard Worker     // a ResizeBilinearLayer (now deprecated and removed) and an output
2403*89c4ff92SAndroid Build Coastguard Worker     //
2404*89c4ff92SAndroid Build Coastguard Worker     // This test verifies that we can still deserialize this old-style model by replacing
2405*89c4ff92SAndroid Build Coastguard Worker     // the ResizeBilinearLayer with an equivalent ResizeLayer
2406*89c4ff92SAndroid Build Coastguard Worker     const std::vector<uint8_t> resizeBilinearModel =
2407*89c4ff92SAndroid Build Coastguard Worker     {
2408*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x0A, 0x00,
2409*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
2410*89c4ff92SAndroid Build Coastguard Worker         0x50, 0x01, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
2411*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xD4, 0xFE, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x0B,
2412*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0xC2, 0xFE, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00,
2413*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8A, 0xFF, 0xFF, 0xFF, 0x02, 0x00, 0x00, 0x00,
2414*89c4ff92SAndroid Build Coastguard Worker         0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00,
2415*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
2416*89c4ff92SAndroid Build Coastguard Worker         0x38, 0xFF, 0xFF, 0xFF, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0xFF, 0xFF, 0xFF, 0x00, 0x00,
2417*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x1A, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0E, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00,
2418*89c4ff92SAndroid Build Coastguard Worker         0x34, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x12, 0x00, 0x08, 0x00, 0x0C, 0x00,
2419*89c4ff92SAndroid Build Coastguard Worker         0x07, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
2420*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x0E, 0x00, 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0E, 0x00,
2421*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x1C, 0x00, 0x00, 0x00,
2422*89c4ff92SAndroid Build Coastguard Worker         0x20, 0x00, 0x00, 0x00, 0x0E, 0x00, 0x00, 0x00, 0x72, 0x65, 0x73, 0x69, 0x7A, 0x65, 0x42, 0x69, 0x6C, 0x69,
2423*89c4ff92SAndroid Build Coastguard Worker         0x6E, 0x65, 0x61, 0x72, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
2424*89c4ff92SAndroid Build Coastguard Worker         0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00,
2425*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x52, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
2426*89c4ff92SAndroid Build Coastguard Worker         0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00,
2427*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
2428*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x07, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
2429*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x09, 0x04, 0x00, 0x00, 0x00, 0xF6, 0xFF, 0xFF, 0xFF, 0x0C, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00,
2430*89c4ff92SAndroid Build Coastguard Worker         0x0A, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, 0x00, 0x14, 0x00,
2431*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0C, 0x00, 0x10, 0x00, 0x0E, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
2432*89c4ff92SAndroid Build Coastguard Worker         0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
2433*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0C, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0A, 0x00,
2434*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x10, 0x00,
2435*89c4ff92SAndroid Build Coastguard Worker         0x08, 0x00, 0x07, 0x00, 0x0C, 0x00, 0x0A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00,
2436*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x05, 0x00,
2437*89c4ff92SAndroid Build Coastguard Worker         0x00, 0x00, 0x05, 0x00, 0x00, 0x00
2438*89c4ff92SAndroid Build Coastguard Worker     };
2439*89c4ff92SAndroid Build Coastguard Worker 
2440*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork =
2441*89c4ff92SAndroid Build Coastguard Worker         DeserializeNetwork(std::string(resizeBilinearModel.begin(), resizeBilinearModel.end()));
2442*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2443*89c4ff92SAndroid Build Coastguard Worker 
2444*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo({1, 3, 5, 5}, armnn::DataType::Float32, 0.0f, 0);
2445*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({1, 3, 2, 4}, armnn::DataType::Float32, 0.0f, 0);
2446*89c4ff92SAndroid Build Coastguard Worker 
2447*89c4ff92SAndroid Build Coastguard Worker     armnn::ResizeDescriptor descriptor;
2448*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TargetWidth  = 4u;
2449*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_TargetHeight = 2u;
2450*89c4ff92SAndroid Build Coastguard Worker 
2451*89c4ff92SAndroid Build Coastguard Worker     ResizeBilinearLayerVerifier verifier("resizeBilinear", { inputInfo }, { outputInfo }, descriptor);
2452*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2453*89c4ff92SAndroid Build Coastguard Worker }
2454*89c4ff92SAndroid Build Coastguard Worker 
2455*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeShape")
2456*89c4ff92SAndroid Build Coastguard Worker {
2457*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("shape");
2458*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({1, 3, 3, 1}, armnn::DataType::Signed32);
2459*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Signed32);
2460*89c4ff92SAndroid Build Coastguard Worker 
2461*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2462*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2463*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const shapeLayer = network->AddShapeLayer(layerName.c_str());
2464*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2465*89c4ff92SAndroid Build Coastguard Worker 
2466*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(shapeLayer->GetInputSlot(0));
2467*89c4ff92SAndroid Build Coastguard Worker     shapeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2468*89c4ff92SAndroid Build Coastguard Worker 
2469*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2470*89c4ff92SAndroid Build Coastguard Worker     shapeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2471*89c4ff92SAndroid Build Coastguard Worker 
2472*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2473*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2474*89c4ff92SAndroid Build Coastguard Worker 
2475*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {inputInfo}, {outputInfo});
2476*89c4ff92SAndroid Build Coastguard Worker 
2477*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2478*89c4ff92SAndroid Build Coastguard Worker }
2479*89c4ff92SAndroid Build Coastguard Worker 
2480*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSlice")
2481*89c4ff92SAndroid Build Coastguard Worker {
2482*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName{"slice"};
2483*89c4ff92SAndroid Build Coastguard Worker 
2484*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo({3, 2, 3, 1}, armnn::DataType::Float32);
2485*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({2, 2, 2, 1}, armnn::DataType::Float32);
2486*89c4ff92SAndroid Build Coastguard Worker 
2487*89c4ff92SAndroid Build Coastguard Worker     armnn::SliceDescriptor descriptor({ 0, 0, 1, 0}, {2, 2, 2, 1});
2488*89c4ff92SAndroid Build Coastguard Worker 
2489*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2490*89c4ff92SAndroid Build Coastguard Worker 
2491*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
2492*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const sliceLayer  = network->AddSliceLayer(descriptor, layerName.c_str());
2493*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2494*89c4ff92SAndroid Build Coastguard Worker 
2495*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(sliceLayer->GetInputSlot(0));
2496*89c4ff92SAndroid Build Coastguard Worker     sliceLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2497*89c4ff92SAndroid Build Coastguard Worker 
2498*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2499*89c4ff92SAndroid Build Coastguard Worker     sliceLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2500*89c4ff92SAndroid Build Coastguard Worker 
2501*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2502*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2503*89c4ff92SAndroid Build Coastguard Worker 
2504*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::SliceDescriptor> verifier(layerName, {inputInfo}, {outputInfo}, descriptor);
2505*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2506*89c4ff92SAndroid Build Coastguard Worker }
2507*89c4ff92SAndroid Build Coastguard Worker 
2508*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSoftmax")
2509*89c4ff92SAndroid Build Coastguard Worker {
2510*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("softmax");
2511*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({1, 10}, armnn::DataType::Float32);
2512*89c4ff92SAndroid Build Coastguard Worker 
2513*89c4ff92SAndroid Build Coastguard Worker     armnn::SoftmaxDescriptor descriptor;
2514*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Beta = 1.0f;
2515*89c4ff92SAndroid Build Coastguard Worker 
2516*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2517*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer   = network->AddInputLayer(0);
2518*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const softmaxLayer = network->AddSoftmaxLayer(descriptor, layerName.c_str());
2519*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer  = network->AddOutputLayer(0);
2520*89c4ff92SAndroid Build Coastguard Worker 
2521*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(softmaxLayer->GetInputSlot(0));
2522*89c4ff92SAndroid Build Coastguard Worker     softmaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2523*89c4ff92SAndroid Build Coastguard Worker 
2524*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
2525*89c4ff92SAndroid Build Coastguard Worker     softmaxLayer->GetOutputSlot(0).SetTensorInfo(info);
2526*89c4ff92SAndroid Build Coastguard Worker 
2527*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2528*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2529*89c4ff92SAndroid Build Coastguard Worker 
2530*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::SoftmaxDescriptor> verifier(layerName, {info}, {info}, descriptor);
2531*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2532*89c4ff92SAndroid Build Coastguard Worker }
2533*89c4ff92SAndroid Build Coastguard Worker 
2534*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSpaceToBatchNd")
2535*89c4ff92SAndroid Build Coastguard Worker {
2536*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("spaceToBatchNd");
2537*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo({2, 1, 2, 4}, armnn::DataType::Float32);
2538*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({8, 1, 1, 3}, armnn::DataType::Float32);
2539*89c4ff92SAndroid Build Coastguard Worker 
2540*89c4ff92SAndroid Build Coastguard Worker     armnn::SpaceToBatchNdDescriptor desc;
2541*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
2542*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockShape = {2, 2};
2543*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadList = {{0, 0}, {2, 0}};
2544*89c4ff92SAndroid Build Coastguard Worker 
2545*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2546*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2547*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const spaceToBatchNdLayer = network->AddSpaceToBatchNdLayer(desc, layerName.c_str());
2548*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2549*89c4ff92SAndroid Build Coastguard Worker 
2550*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(spaceToBatchNdLayer->GetInputSlot(0));
2551*89c4ff92SAndroid Build Coastguard Worker     spaceToBatchNdLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2552*89c4ff92SAndroid Build Coastguard Worker 
2553*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2554*89c4ff92SAndroid Build Coastguard Worker     spaceToBatchNdLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2555*89c4ff92SAndroid Build Coastguard Worker 
2556*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2557*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2558*89c4ff92SAndroid Build Coastguard Worker 
2559*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::SpaceToBatchNdDescriptor> verifier(
2560*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, desc);
2561*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2562*89c4ff92SAndroid Build Coastguard Worker }
2563*89c4ff92SAndroid Build Coastguard Worker 
2564*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSpaceToDepth")
2565*89c4ff92SAndroid Build Coastguard Worker {
2566*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("spaceToDepth");
2567*89c4ff92SAndroid Build Coastguard Worker 
2568*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 16, 8,  3 }, armnn::DataType::Float32);
2569*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1,  8, 4, 12 }, armnn::DataType::Float32);
2570*89c4ff92SAndroid Build Coastguard Worker 
2571*89c4ff92SAndroid Build Coastguard Worker     armnn::SpaceToDepthDescriptor desc;
2572*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockSize  = 2;
2573*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
2574*89c4ff92SAndroid Build Coastguard Worker 
2575*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2576*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer        = network->AddInputLayer(0);
2577*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const spaceToDepthLayer = network->AddSpaceToDepthLayer(desc, layerName.c_str());
2578*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer       = network->AddOutputLayer(0);
2579*89c4ff92SAndroid Build Coastguard Worker 
2580*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(spaceToDepthLayer->GetInputSlot(0));
2581*89c4ff92SAndroid Build Coastguard Worker     spaceToDepthLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2582*89c4ff92SAndroid Build Coastguard Worker 
2583*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2584*89c4ff92SAndroid Build Coastguard Worker     spaceToDepthLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2585*89c4ff92SAndroid Build Coastguard Worker 
2586*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2587*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2588*89c4ff92SAndroid Build Coastguard Worker 
2589*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::SpaceToDepthDescriptor> verifier(
2590*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, desc);
2591*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2592*89c4ff92SAndroid Build Coastguard Worker }
2593*89c4ff92SAndroid Build Coastguard Worker 
2594*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSplitter")
2595*89c4ff92SAndroid Build Coastguard Worker {
2596*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numViews = 3;
2597*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numDimensions = 4;
2598*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[] = {1, 18, 4, 4};
2599*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = {1, 6, 4, 4};
2600*89c4ff92SAndroid Build Coastguard Worker 
2601*89c4ff92SAndroid Build Coastguard Worker     // This is modelled on how the caffe parser sets up a splitter layer to partition an input along dimension one.
2602*89c4ff92SAndroid Build Coastguard Worker     unsigned int splitterDimSizes[4] = {static_cast<unsigned int>(inputShape[0]),
2603*89c4ff92SAndroid Build Coastguard Worker                                         static_cast<unsigned int>(inputShape[1]),
2604*89c4ff92SAndroid Build Coastguard Worker                                         static_cast<unsigned int>(inputShape[2]),
2605*89c4ff92SAndroid Build Coastguard Worker                                         static_cast<unsigned int>(inputShape[3])};
2606*89c4ff92SAndroid Build Coastguard Worker     splitterDimSizes[1] /= numViews;
2607*89c4ff92SAndroid Build Coastguard Worker     armnn::ViewsDescriptor desc(numViews, numDimensions);
2608*89c4ff92SAndroid Build Coastguard Worker 
2609*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int g = 0; g < numViews; ++g)
2610*89c4ff92SAndroid Build Coastguard Worker     {
2611*89c4ff92SAndroid Build Coastguard Worker         desc.SetViewOriginCoord(g, 1, splitterDimSizes[1] * g);
2612*89c4ff92SAndroid Build Coastguard Worker 
2613*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx=0; dimIdx < 4; dimIdx++)
2614*89c4ff92SAndroid Build Coastguard Worker         {
2615*89c4ff92SAndroid Build Coastguard Worker             desc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
2616*89c4ff92SAndroid Build Coastguard Worker         }
2617*89c4ff92SAndroid Build Coastguard Worker     }
2618*89c4ff92SAndroid Build Coastguard Worker 
2619*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("splitter");
2620*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo(numDimensions, inputShape, armnn::DataType::Float32);
2621*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo(numDimensions, outputShape, armnn::DataType::Float32);
2622*89c4ff92SAndroid Build Coastguard Worker 
2623*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2624*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2625*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const splitterLayer = network->AddSplitterLayer(desc, layerName.c_str());
2626*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0);
2627*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer1 = network->AddOutputLayer(1);
2628*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer2 = network->AddOutputLayer(2);
2629*89c4ff92SAndroid Build Coastguard Worker 
2630*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(splitterLayer->GetInputSlot(0));
2631*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0));
2632*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(1).Connect(outputLayer1->GetInputSlot(0));
2633*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(2).Connect(outputLayer2->GetInputSlot(0));
2634*89c4ff92SAndroid Build Coastguard Worker 
2635*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2636*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2637*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(1).SetTensorInfo(outputInfo);
2638*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->GetOutputSlot(2).SetTensorInfo(outputInfo);
2639*89c4ff92SAndroid Build Coastguard Worker 
2640*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2641*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2642*89c4ff92SAndroid Build Coastguard Worker 
2643*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::ViewsDescriptor> verifier(
2644*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo, outputInfo, outputInfo}, desc);
2645*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2646*89c4ff92SAndroid Build Coastguard Worker }
2647*89c4ff92SAndroid Build Coastguard Worker 
2648*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeStack")
2649*89c4ff92SAndroid Build Coastguard Worker {
2650*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("stack");
2651*89c4ff92SAndroid Build Coastguard Worker 
2652*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo ({4, 3, 5}, armnn::DataType::Float32);
2653*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo({4, 3, 2, 5}, armnn::DataType::Float32);
2654*89c4ff92SAndroid Build Coastguard Worker 
2655*89c4ff92SAndroid Build Coastguard Worker     armnn::StackDescriptor descriptor(2, 2, {4, 3, 5});
2656*89c4ff92SAndroid Build Coastguard Worker 
2657*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2658*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(0);
2659*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer2 = network->AddInputLayer(1);
2660*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const stackLayer = network->AddStackLayer(descriptor, layerName.c_str());
2661*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2662*89c4ff92SAndroid Build Coastguard Worker 
2663*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(stackLayer->GetInputSlot(0));
2664*89c4ff92SAndroid Build Coastguard Worker     inputLayer2->GetOutputSlot(0).Connect(stackLayer->GetInputSlot(1));
2665*89c4ff92SAndroid Build Coastguard Worker     stackLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2666*89c4ff92SAndroid Build Coastguard Worker 
2667*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2668*89c4ff92SAndroid Build Coastguard Worker     inputLayer2->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2669*89c4ff92SAndroid Build Coastguard Worker     stackLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2670*89c4ff92SAndroid Build Coastguard Worker 
2671*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2672*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2673*89c4ff92SAndroid Build Coastguard Worker 
2674*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::StackDescriptor> verifier(
2675*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputTensorInfo, inputTensorInfo}, {outputTensorInfo}, descriptor);
2676*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2677*89c4ff92SAndroid Build Coastguard Worker }
2678*89c4ff92SAndroid Build Coastguard Worker 
2679*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeStandIn")
2680*89c4ff92SAndroid Build Coastguard Worker {
2681*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("standIn");
2682*89c4ff92SAndroid Build Coastguard Worker 
2683*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo tensorInfo({ 1u }, armnn::DataType::Float32);
2684*89c4ff92SAndroid Build Coastguard Worker     armnn::StandInDescriptor descriptor(2u, 2u);
2685*89c4ff92SAndroid Build Coastguard Worker 
2686*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2687*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0  = network->AddInputLayer(0);
2688*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1  = network->AddInputLayer(1);
2689*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const standInLayer = network->AddStandInLayer(descriptor, layerName.c_str());
2690*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer0 = network->AddOutputLayer(0);
2691*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer1 = network->AddOutputLayer(1);
2692*89c4ff92SAndroid Build Coastguard Worker 
2693*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(standInLayer->GetInputSlot(0));
2694*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(tensorInfo);
2695*89c4ff92SAndroid Build Coastguard Worker 
2696*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(standInLayer->GetInputSlot(1));
2697*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(tensorInfo);
2698*89c4ff92SAndroid Build Coastguard Worker 
2699*89c4ff92SAndroid Build Coastguard Worker     standInLayer->GetOutputSlot(0).Connect(outputLayer0->GetInputSlot(0));
2700*89c4ff92SAndroid Build Coastguard Worker     standInLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
2701*89c4ff92SAndroid Build Coastguard Worker 
2702*89c4ff92SAndroid Build Coastguard Worker     standInLayer->GetOutputSlot(1).Connect(outputLayer1->GetInputSlot(0));
2703*89c4ff92SAndroid Build Coastguard Worker     standInLayer->GetOutputSlot(1).SetTensorInfo(tensorInfo);
2704*89c4ff92SAndroid Build Coastguard Worker 
2705*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2706*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2707*89c4ff92SAndroid Build Coastguard Worker 
2708*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::StandInDescriptor> verifier(
2709*89c4ff92SAndroid Build Coastguard Worker             layerName, { tensorInfo, tensorInfo }, { tensorInfo, tensorInfo }, descriptor);
2710*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2711*89c4ff92SAndroid Build Coastguard Worker }
2712*89c4ff92SAndroid Build Coastguard Worker 
2713*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeStridedSlice")
2714*89c4ff92SAndroid Build Coastguard Worker {
2715*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("stridedSlice");
2716*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo = armnn::TensorInfo({3, 2, 3, 1}, armnn::DataType::Float32);
2717*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo({3, 1}, armnn::DataType::Float32);
2718*89c4ff92SAndroid Build Coastguard Worker 
2719*89c4ff92SAndroid Build Coastguard Worker     armnn::StridedSliceDescriptor desc({0, 0, 1, 0}, {1, 1, 1, 1}, {1, 1, 1, 1});
2720*89c4ff92SAndroid Build Coastguard Worker     desc.m_EndMask = (1 << 4) - 1;
2721*89c4ff92SAndroid Build Coastguard Worker     desc.m_ShrinkAxisMask = (1 << 1) | (1 << 2);
2722*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NCHW;
2723*89c4ff92SAndroid Build Coastguard Worker 
2724*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2725*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2726*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const stridedSliceLayer = network->AddStridedSliceLayer(desc, layerName.c_str());
2727*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2728*89c4ff92SAndroid Build Coastguard Worker 
2729*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(stridedSliceLayer->GetInputSlot(0));
2730*89c4ff92SAndroid Build Coastguard Worker     stridedSliceLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2731*89c4ff92SAndroid Build Coastguard Worker 
2732*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2733*89c4ff92SAndroid Build Coastguard Worker     stridedSliceLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2734*89c4ff92SAndroid Build Coastguard Worker 
2735*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2736*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2737*89c4ff92SAndroid Build Coastguard Worker 
2738*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::StridedSliceDescriptor> verifier(
2739*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, desc);
2740*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2741*89c4ff92SAndroid Build Coastguard Worker }
2742*89c4ff92SAndroid Build Coastguard Worker 
2743*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSubtraction")
2744*89c4ff92SAndroid Build Coastguard Worker {
2745*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("subtraction");
2746*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);
2747*89c4ff92SAndroid Build Coastguard Worker 
2748*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2749*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
2750*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
2751*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
2752*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const subtractionLayer = network->AddSubtractionLayer(layerName.c_str());
2753*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
2754*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2755*89c4ff92SAndroid Build Coastguard Worker 
2756*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).Connect(subtractionLayer->GetInputSlot(0));
2757*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).Connect(subtractionLayer->GetInputSlot(1));
2758*89c4ff92SAndroid Build Coastguard Worker     subtractionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2759*89c4ff92SAndroid Build Coastguard Worker 
2760*89c4ff92SAndroid Build Coastguard Worker     inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
2761*89c4ff92SAndroid Build Coastguard Worker     inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
2762*89c4ff92SAndroid Build Coastguard Worker     subtractionLayer->GetOutputSlot(0).SetTensorInfo(info);
2763*89c4ff92SAndroid Build Coastguard Worker 
2764*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2765*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2766*89c4ff92SAndroid Build Coastguard Worker 
2767*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase verifier(layerName, {info, info}, {info});
2768*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2769*89c4ff92SAndroid Build Coastguard Worker }
2770*89c4ff92SAndroid Build Coastguard Worker 
2771*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeSwitch")
2772*89c4ff92SAndroid Build Coastguard Worker {
2773*89c4ff92SAndroid Build Coastguard Worker     class SwitchLayerVerifier : public LayerVerifierBase
2774*89c4ff92SAndroid Build Coastguard Worker     {
2775*89c4ff92SAndroid Build Coastguard Worker     public:
SwitchLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)2776*89c4ff92SAndroid Build Coastguard Worker         SwitchLayerVerifier(const std::string& layerName,
2777*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& inputInfos,
2778*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& outputInfos)
2779*89c4ff92SAndroid Build Coastguard Worker                 : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
2780*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)2781*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
2782*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor& descriptor,
2783*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
2784*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
2785*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
2786*89c4ff92SAndroid Build Coastguard Worker         {
2787*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(descriptor, constants, id);
2788*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
2789*89c4ff92SAndroid Build Coastguard Worker             {
2790*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input: break;
2791*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output: break;
2792*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Constant: break;
2793*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Switch:
2794*89c4ff92SAndroid Build Coastguard Worker                 {
2795*89c4ff92SAndroid Build Coastguard Worker                     VerifyNameAndConnections(layer, name);
2796*89c4ff92SAndroid Build Coastguard Worker                     break;
2797*89c4ff92SAndroid Build Coastguard Worker                 }
2798*89c4ff92SAndroid Build Coastguard Worker                 default:
2799*89c4ff92SAndroid Build Coastguard Worker                 {
2800*89c4ff92SAndroid Build Coastguard Worker                     throw armnn::Exception("Unexpected layer type in Switch test model");
2801*89c4ff92SAndroid Build Coastguard Worker                 }
2802*89c4ff92SAndroid Build Coastguard Worker             }
2803*89c4ff92SAndroid Build Coastguard Worker         }
2804*89c4ff92SAndroid Build Coastguard Worker     };
2805*89c4ff92SAndroid Build Coastguard Worker 
2806*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("switch");
2807*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32, 0.0f, 0, true);
2808*89c4ff92SAndroid Build Coastguard Worker 
2809*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> constantData = GenerateRandomData<float>(info.GetNumElements());
2810*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor constTensor(info, constantData);
2811*89c4ff92SAndroid Build Coastguard Worker 
2812*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2813*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2814*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const constantLayer = network->AddConstantLayer(constTensor, "constant");
2815*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const switchLayer = network->AddSwitchLayer(layerName.c_str());
2816*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const trueOutputLayer = network->AddOutputLayer(0);
2817*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const falseOutputLayer = network->AddOutputLayer(1);
2818*89c4ff92SAndroid Build Coastguard Worker 
2819*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(0));
2820*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).Connect(switchLayer->GetInputSlot(1));
2821*89c4ff92SAndroid Build Coastguard Worker     switchLayer->GetOutputSlot(0).Connect(trueOutputLayer->GetInputSlot(0));
2822*89c4ff92SAndroid Build Coastguard Worker     switchLayer->GetOutputSlot(1).Connect(falseOutputLayer->GetInputSlot(0));
2823*89c4ff92SAndroid Build Coastguard Worker 
2824*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(info);
2825*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).SetTensorInfo(info);
2826*89c4ff92SAndroid Build Coastguard Worker     switchLayer->GetOutputSlot(0).SetTensorInfo(info);
2827*89c4ff92SAndroid Build Coastguard Worker     switchLayer->GetOutputSlot(1).SetTensorInfo(info);
2828*89c4ff92SAndroid Build Coastguard Worker 
2829*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2830*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2831*89c4ff92SAndroid Build Coastguard Worker 
2832*89c4ff92SAndroid Build Coastguard Worker     SwitchLayerVerifier verifier(layerName, {info, info}, {info, info});
2833*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2834*89c4ff92SAndroid Build Coastguard Worker }
2835*89c4ff92SAndroid Build Coastguard Worker 
2836*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeTranspose")
2837*89c4ff92SAndroid Build Coastguard Worker {
2838*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("transpose");
2839*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputTensorInfo({4, 3, 2, 1}, armnn::DataType::Float32);
2840*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputTensorInfo({1, 2, 3, 4}, armnn::DataType::Float32);
2841*89c4ff92SAndroid Build Coastguard Worker 
2842*89c4ff92SAndroid Build Coastguard Worker     armnn::TransposeDescriptor descriptor(armnn::PermutationVector({3, 2, 1, 0}));
2843*89c4ff92SAndroid Build Coastguard Worker 
2844*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2845*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0);
2846*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const transposeLayer = network->AddTransposeLayer(descriptor, layerName.c_str());
2847*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2848*89c4ff92SAndroid Build Coastguard Worker 
2849*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(transposeLayer->GetInputSlot(0));
2850*89c4ff92SAndroid Build Coastguard Worker     transposeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2851*89c4ff92SAndroid Build Coastguard Worker 
2852*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
2853*89c4ff92SAndroid Build Coastguard Worker     transposeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2854*89c4ff92SAndroid Build Coastguard Worker 
2855*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2856*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2857*89c4ff92SAndroid Build Coastguard Worker 
2858*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor<armnn::TransposeDescriptor> verifier(
2859*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputTensorInfo}, {outputTensorInfo}, descriptor);
2860*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2861*89c4ff92SAndroid Build Coastguard Worker }
2862*89c4ff92SAndroid Build Coastguard Worker 
2863*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeTransposeConvolution2d")
2864*89c4ff92SAndroid Build Coastguard Worker {
2865*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("transposeConvolution2d");
2866*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo ({ 1, 7, 7, 1 }, armnn::DataType::Float32);
2867*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo({ 1, 9, 9, 1 }, armnn::DataType::Float32);
2868*89c4ff92SAndroid Build Coastguard Worker 
2869*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo weightsInfo({ 1, 3, 3, 1 }, armnn::DataType::Float32, 0.0f, 0, true);
2870*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo biasesInfo ({ 1 }, armnn::DataType::Float32, 0.0f, 0, true);
2871*89c4ff92SAndroid Build Coastguard Worker 
2872*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsData = GenerateRandomData<float>(weightsInfo.GetNumElements());
2873*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsData);
2874*89c4ff92SAndroid Build Coastguard Worker 
2875*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasesData = GenerateRandomData<float>(biasesInfo.GetNumElements());
2876*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor biases(biasesInfo, biasesData);
2877*89c4ff92SAndroid Build Coastguard Worker 
2878*89c4ff92SAndroid Build Coastguard Worker     armnn::TransposeConvolution2dDescriptor descriptor;
2879*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadLeft     = 1;
2880*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadRight    = 1;
2881*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadTop      = 1;
2882*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PadBottom   = 1;
2883*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX     = 1;
2884*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY     = 1;
2885*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BiasEnabled = true;
2886*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout  = armnn::DataLayout::NHWC;
2887*89c4ff92SAndroid Build Coastguard Worker 
2888*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
2889*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer  = network->AddInputLayer(0);
2890*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const convLayer   =
2891*89c4ff92SAndroid Build Coastguard Worker             network->AddTransposeConvolution2dLayer(descriptor,
2892*89c4ff92SAndroid Build Coastguard Worker                                                     weights,
2893*89c4ff92SAndroid Build Coastguard Worker                                                     armnn::Optional<armnn::ConstTensor>(biases),
2894*89c4ff92SAndroid Build Coastguard Worker                                                     layerName.c_str());
2895*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
2896*89c4ff92SAndroid Build Coastguard Worker 
2897*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
2898*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
2899*89c4ff92SAndroid Build Coastguard Worker 
2900*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
2901*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
2902*89c4ff92SAndroid Build Coastguard Worker 
2903*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2904*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2905*89c4ff92SAndroid Build Coastguard Worker 
2906*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::ConstTensor> constants {weights, biases};
2907*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants<armnn::TransposeConvolution2dDescriptor> verifier(
2908*89c4ff92SAndroid Build Coastguard Worker             layerName, {inputInfo}, {outputInfo}, descriptor, constants);
2909*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2910*89c4ff92SAndroid Build Coastguard Worker }
2911*89c4ff92SAndroid Build Coastguard Worker 
2912*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeDeserializeNonLinearNetwork")
2913*89c4ff92SAndroid Build Coastguard Worker {
2914*89c4ff92SAndroid Build Coastguard Worker     class ConstantLayerVerifier : public LayerVerifierBase
2915*89c4ff92SAndroid Build Coastguard Worker     {
2916*89c4ff92SAndroid Build Coastguard Worker     public:
ConstantLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ConstTensor & layerInput)2917*89c4ff92SAndroid Build Coastguard Worker         ConstantLayerVerifier(const std::string& layerName,
2918*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<armnn::TensorInfo>& inputInfos,
2919*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<armnn::TensorInfo>& outputInfos,
2920*89c4ff92SAndroid Build Coastguard Worker                               const armnn::ConstTensor& layerInput)
2921*89c4ff92SAndroid Build Coastguard Worker             : LayerVerifierBase(layerName, inputInfos, outputInfos)
2922*89c4ff92SAndroid Build Coastguard Worker             , m_LayerInput(layerInput) {}
2923*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)2924*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
2925*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor& descriptor,
2926*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
2927*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
2928*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
2929*89c4ff92SAndroid Build Coastguard Worker         {
2930*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(descriptor, constants, id);
2931*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
2932*89c4ff92SAndroid Build Coastguard Worker             {
2933*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input: break;
2934*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output: break;
2935*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Addition: break;
2936*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Constant:
2937*89c4ff92SAndroid Build Coastguard Worker                 {
2938*89c4ff92SAndroid Build Coastguard Worker                     VerifyNameAndConnections(layer, name);
2939*89c4ff92SAndroid Build Coastguard Worker                     CompareConstTensor(constants.at(0), m_LayerInput);
2940*89c4ff92SAndroid Build Coastguard Worker                     break;
2941*89c4ff92SAndroid Build Coastguard Worker                 }
2942*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::ElementwiseBinary: break;
2943*89c4ff92SAndroid Build Coastguard Worker                 default:
2944*89c4ff92SAndroid Build Coastguard Worker                 {
2945*89c4ff92SAndroid Build Coastguard Worker                     throw armnn::Exception("Unexpected layer type in test model");
2946*89c4ff92SAndroid Build Coastguard Worker                 }
2947*89c4ff92SAndroid Build Coastguard Worker             }
2948*89c4ff92SAndroid Build Coastguard Worker         }
2949*89c4ff92SAndroid Build Coastguard Worker 
2950*89c4ff92SAndroid Build Coastguard Worker     private:
2951*89c4ff92SAndroid Build Coastguard Worker         armnn::ConstTensor m_LayerInput;
2952*89c4ff92SAndroid Build Coastguard Worker     };
2953*89c4ff92SAndroid Build Coastguard Worker 
2954*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("constant");
2955*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 2, 3 }, armnn::DataType::Float32, 0.0f, 0, true);
2956*89c4ff92SAndroid Build Coastguard Worker 
2957*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> constantData = GenerateRandomData<float>(info.GetNumElements());
2958*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor constTensor(info, constantData);
2959*89c4ff92SAndroid Build Coastguard Worker 
2960*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network(armnn::INetwork::Create());
2961*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* input = network->AddInputLayer(0);
2962*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
2963*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* add = network->AddAdditionLayer();
2964*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
2965*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* constant = network->AddConstantLayer(constTensor, layerName.c_str());
2966*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* output = network->AddOutputLayer(0);
2967*89c4ff92SAndroid Build Coastguard Worker 
2968*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(add->GetInputSlot(0));
2969*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).Connect(add->GetInputSlot(1));
2970*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).Connect(output->GetInputSlot(0));
2971*89c4ff92SAndroid Build Coastguard Worker 
2972*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(info);
2973*89c4ff92SAndroid Build Coastguard Worker     constant->GetOutputSlot(0).SetTensorInfo(info);
2974*89c4ff92SAndroid Build Coastguard Worker     add->GetOutputSlot(0).SetTensorInfo(info);
2975*89c4ff92SAndroid Build Coastguard Worker 
2976*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
2977*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
2978*89c4ff92SAndroid Build Coastguard Worker 
2979*89c4ff92SAndroid Build Coastguard Worker     ConstantLayerVerifier verifier(layerName, {}, {info}, constTensor);
2980*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
2981*89c4ff92SAndroid Build Coastguard Worker }
2982*89c4ff92SAndroid Build Coastguard Worker 
2983*89c4ff92SAndroid Build Coastguard Worker }