1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <armnn/utility/NumericCast.hpp>

#include <CommonTestUtils.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker template<typename armnn::DataType DataType>
CreateSplitterNetwork(const TensorShape & inputShape,const std::vector<TensorShape> & outputShapes,unsigned int splitAxis,unsigned int numSplit,const float qScale=1.0f,const int32_t qOffset=0)23*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateSplitterNetwork(const TensorShape& inputShape,
24*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape>& outputShapes,
25*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis,
26*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit,
27*89c4ff92SAndroid Build Coastguard Worker const float qScale = 1.0f,
28*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = 0)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
31*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network.
32*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, DataType, qScale, qOffset, true);
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> splitterDimSizes(inputShape.GetNumDimensions());
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker // Add current input shape to splitterDimSizes
39*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker if (splitterDimSizes[splitAxis] % numSplit != 0)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker throw ParseException("Number of splits must evenly divide the dimension");
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker splitterDimSizes[splitAxis] /= numSplit;
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker SplitterDescriptor splitDesc(numSplit, inputShape.GetNumDimensions());
51*89c4ff92SAndroid Build Coastguard Worker for (unsigned int g = 0; g < numSplit; ++g)
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker // Set the size of the views.
54*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker splitDesc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker splitDesc.SetViewOriginCoord(g, splitAxis, splitterDimSizes[splitAxis] * g);
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* splitter = net->AddSplitterLayer(splitDesc, "splitter");
62*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = net->AddInputLayer(0, "input");
63*89c4ff92SAndroid Build Coastguard Worker Connect(input, splitter, inputTensorInfo, 0, 0);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputShapes.size(); ++i)
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShapes[i], DataType, qScale, qOffset);
68*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(armnn::numeric_cast<LayerBindingId>(i));
69*89c4ff92SAndroid Build Coastguard Worker Connect(splitter, output, outputTensorInfo, i, 0);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker return net;
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter1dEndToEnd(const std::vector<BackendId> & backends)76*89c4ff92SAndroid Build Coastguard Worker void Splitter1dEndToEnd(const std::vector<BackendId>& backends)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
79*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 0;
82*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
83*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 4 };
84*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2 }, { 2 }};
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
87*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
92*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{ 1, 2, 3, 4 };
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{ 1, 2 };
95*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{ 3, 4 };
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
98*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 }, {1, expectedOutput1} };
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter2dDim0EndToEnd(const std::vector<BackendId> & backends)104*89c4ff92SAndroid Build Coastguard Worker void Splitter2dDim0EndToEnd(const std::vector<BackendId>& backends)
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
107*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 0;
110*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
111*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 4, 3 };
112*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 3 }, { 2, 3 }};
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
115*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
120*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
121*89c4ff92SAndroid Build Coastguard Worker 1, 2,
122*89c4ff92SAndroid Build Coastguard Worker 3, 4,
123*89c4ff92SAndroid Build Coastguard Worker 5, 6,
124*89c4ff92SAndroid Build Coastguard Worker 7, 8,
125*89c4ff92SAndroid Build Coastguard Worker 9, 10,
126*89c4ff92SAndroid Build Coastguard Worker 11, 12
127*89c4ff92SAndroid Build Coastguard Worker };
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{ 1, 2, 3, 4, 5, 6 };
130*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{ 7, 8, 9, 10, 11, 12 };
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
133*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 }, {1, expectedOutput1} };
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter2dDim1EndToEnd(const std::vector<BackendId> & backends)139*89c4ff92SAndroid Build Coastguard Worker void Splitter2dDim1EndToEnd(const std::vector<BackendId>& backends)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
142*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 1;
145*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 3;
146*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 4, 3 };
147*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 4, 1 }, { 4, 1 }, { 4, 1 }};
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
150*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
155*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
156*89c4ff92SAndroid Build Coastguard Worker 1, 2,
157*89c4ff92SAndroid Build Coastguard Worker 3, 4,
158*89c4ff92SAndroid Build Coastguard Worker 5, 6,
159*89c4ff92SAndroid Build Coastguard Worker 7, 8,
160*89c4ff92SAndroid Build Coastguard Worker 9, 10,
161*89c4ff92SAndroid Build Coastguard Worker 11, 12
162*89c4ff92SAndroid Build Coastguard Worker };
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{ 1, 4, 7, 10 };
165*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{ 2, 5, 8, 11 };
166*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput2{ 3, 6, 9, 12 };
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
169*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 },
170*89c4ff92SAndroid Build Coastguard Worker { 1, expectedOutput1 },
171*89c4ff92SAndroid Build Coastguard Worker { 2, expectedOutput2 } };
172*89c4ff92SAndroid Build Coastguard Worker
173*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter3dDim0EndToEnd(const std::vector<BackendId> & backends)177*89c4ff92SAndroid Build Coastguard Worker void Splitter3dDim0EndToEnd(const std::vector<BackendId>& backends)
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
180*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 0;
183*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
184*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 4, 3 };
185*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 1, 4, 3 }, { 1, 4, 3 }};
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
188*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
191*89c4ff92SAndroid Build Coastguard Worker
192*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
193*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
194*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3,
195*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6,
196*89c4ff92SAndroid Build Coastguard Worker 7, 8, 9,
197*89c4ff92SAndroid Build Coastguard Worker 10, 11, 12,
198*89c4ff92SAndroid Build Coastguard Worker 13, 14, 15,
199*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18,
200*89c4ff92SAndroid Build Coastguard Worker 19, 20, 21,
201*89c4ff92SAndroid Build Coastguard Worker 22, 23, 24
202*89c4ff92SAndroid Build Coastguard Worker };
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
205*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3,
206*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6,
207*89c4ff92SAndroid Build Coastguard Worker 7, 8, 9,
208*89c4ff92SAndroid Build Coastguard Worker 10, 11, 12
209*89c4ff92SAndroid Build Coastguard Worker };
210*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
211*89c4ff92SAndroid Build Coastguard Worker 13, 14, 15,
212*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18,
213*89c4ff92SAndroid Build Coastguard Worker 19, 20, 21,
214*89c4ff92SAndroid Build Coastguard Worker 22, 23, 24
215*89c4ff92SAndroid Build Coastguard Worker };
216*89c4ff92SAndroid Build Coastguard Worker
217*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
218*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 },
219*89c4ff92SAndroid Build Coastguard Worker { 1, expectedOutput1 } };
220*89c4ff92SAndroid Build Coastguard Worker
221*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter3dDim1EndToEnd(const std::vector<BackendId> & backends)225*89c4ff92SAndroid Build Coastguard Worker void Splitter3dDim1EndToEnd(const std::vector<BackendId>& backends)
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
228*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 1;
231*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
232*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 4, 3 };
233*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 2, 3 }, { 2, 2, 3 }};
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
236*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
241*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
242*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3,
243*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6,
244*89c4ff92SAndroid Build Coastguard Worker 7, 8, 9,
245*89c4ff92SAndroid Build Coastguard Worker 10, 11, 12,
246*89c4ff92SAndroid Build Coastguard Worker 13, 14, 15,
247*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18,
248*89c4ff92SAndroid Build Coastguard Worker 19, 20, 21,
249*89c4ff92SAndroid Build Coastguard Worker 22, 23, 24
250*89c4ff92SAndroid Build Coastguard Worker };
251*89c4ff92SAndroid Build Coastguard Worker
252*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
253*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3,
254*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6,
255*89c4ff92SAndroid Build Coastguard Worker 13, 14, 15,
256*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18
257*89c4ff92SAndroid Build Coastguard Worker };
258*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
259*89c4ff92SAndroid Build Coastguard Worker 7, 8, 9,
260*89c4ff92SAndroid Build Coastguard Worker 10, 11, 12,
261*89c4ff92SAndroid Build Coastguard Worker 19, 20, 21,
262*89c4ff92SAndroid Build Coastguard Worker 22, 23, 24
263*89c4ff92SAndroid Build Coastguard Worker };
264*89c4ff92SAndroid Build Coastguard Worker
265*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
266*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 },
267*89c4ff92SAndroid Build Coastguard Worker { 1, expectedOutput1 } };
268*89c4ff92SAndroid Build Coastguard Worker
269*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
270*89c4ff92SAndroid Build Coastguard Worker }
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter3dDim2EndToEnd(const std::vector<BackendId> & backends)273*89c4ff92SAndroid Build Coastguard Worker void Splitter3dDim2EndToEnd(const std::vector<BackendId>& backends)
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
276*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
277*89c4ff92SAndroid Build Coastguard Worker
278*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 2;
279*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 3;
280*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 4, 3 };
281*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 4, 1 }, { 2, 4, 1 }, { 2, 4, 1 }};
282*89c4ff92SAndroid Build Coastguard Worker
283*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
284*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
285*89c4ff92SAndroid Build Coastguard Worker
286*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
289*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
290*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3,
291*89c4ff92SAndroid Build Coastguard Worker 4, 5, 6,
292*89c4ff92SAndroid Build Coastguard Worker 7, 8, 9,
293*89c4ff92SAndroid Build Coastguard Worker 10, 11, 12,
294*89c4ff92SAndroid Build Coastguard Worker 13, 14, 15,
295*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18,
296*89c4ff92SAndroid Build Coastguard Worker 19, 20, 21,
297*89c4ff92SAndroid Build Coastguard Worker 22, 23, 24
298*89c4ff92SAndroid Build Coastguard Worker };
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{ 1, 4, 7, 10, 13, 16, 19, 22 };
301*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{ 2, 5, 8, 11, 14, 17, 20, 23 };
302*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput2{ 3, 6, 9, 12, 15, 18, 21, 24 };
303*89c4ff92SAndroid Build Coastguard Worker
304*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = { { 0, inputData } };
305*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput0 },
306*89c4ff92SAndroid Build Coastguard Worker { 1, expectedOutput1 },
307*89c4ff92SAndroid Build Coastguard Worker { 2, expectedOutput2 } };
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter4dDim0EndToEnd(const std::vector<BackendId> & backends)313*89c4ff92SAndroid Build Coastguard Worker void Splitter4dDim0EndToEnd(const std::vector<BackendId>& backends)
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
316*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
317*89c4ff92SAndroid Build Coastguard Worker
318*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 0;
319*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
320*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 4, 3, 2, 2 };
321*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
322*89c4ff92SAndroid Build Coastguard Worker
323*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
324*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
327*89c4ff92SAndroid Build Coastguard Worker
328*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
329*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
330*89c4ff92SAndroid Build Coastguard Worker 1, 2,
331*89c4ff92SAndroid Build Coastguard Worker 3, 4,
332*89c4ff92SAndroid Build Coastguard Worker 5, 6,
333*89c4ff92SAndroid Build Coastguard Worker 7, 8,
334*89c4ff92SAndroid Build Coastguard Worker 9, 10,
335*89c4ff92SAndroid Build Coastguard Worker 11, 12,
336*89c4ff92SAndroid Build Coastguard Worker 13, 14,
337*89c4ff92SAndroid Build Coastguard Worker 15, 16,
338*89c4ff92SAndroid Build Coastguard Worker 17, 18,
339*89c4ff92SAndroid Build Coastguard Worker 19, 20,
340*89c4ff92SAndroid Build Coastguard Worker 21, 22,
341*89c4ff92SAndroid Build Coastguard Worker 23, 24,
342*89c4ff92SAndroid Build Coastguard Worker 25, 26,
343*89c4ff92SAndroid Build Coastguard Worker 27, 28,
344*89c4ff92SAndroid Build Coastguard Worker 29, 30,
345*89c4ff92SAndroid Build Coastguard Worker 31, 32,
346*89c4ff92SAndroid Build Coastguard Worker 33, 34,
347*89c4ff92SAndroid Build Coastguard Worker 35, 36,
348*89c4ff92SAndroid Build Coastguard Worker 37, 38,
349*89c4ff92SAndroid Build Coastguard Worker 39, 40,
350*89c4ff92SAndroid Build Coastguard Worker 41, 42,
351*89c4ff92SAndroid Build Coastguard Worker 43, 44,
352*89c4ff92SAndroid Build Coastguard Worker 45, 46,
353*89c4ff92SAndroid Build Coastguard Worker 47, 48
354*89c4ff92SAndroid Build Coastguard Worker };
355*89c4ff92SAndroid Build Coastguard Worker
356*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
357*89c4ff92SAndroid Build Coastguard Worker 1, 2,
358*89c4ff92SAndroid Build Coastguard Worker 3, 4,
359*89c4ff92SAndroid Build Coastguard Worker 5, 6,
360*89c4ff92SAndroid Build Coastguard Worker 7, 8,
361*89c4ff92SAndroid Build Coastguard Worker 9, 10,
362*89c4ff92SAndroid Build Coastguard Worker 11, 12,
363*89c4ff92SAndroid Build Coastguard Worker 13, 14,
364*89c4ff92SAndroid Build Coastguard Worker 15, 16,
365*89c4ff92SAndroid Build Coastguard Worker 17, 18,
366*89c4ff92SAndroid Build Coastguard Worker 19, 20,
367*89c4ff92SAndroid Build Coastguard Worker 21, 22,
368*89c4ff92SAndroid Build Coastguard Worker 23, 24
369*89c4ff92SAndroid Build Coastguard Worker };
370*89c4ff92SAndroid Build Coastguard Worker
371*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
372*89c4ff92SAndroid Build Coastguard Worker 25, 26,
373*89c4ff92SAndroid Build Coastguard Worker 27, 28,
374*89c4ff92SAndroid Build Coastguard Worker 29, 30,
375*89c4ff92SAndroid Build Coastguard Worker 31, 32,
376*89c4ff92SAndroid Build Coastguard Worker 33, 34,
377*89c4ff92SAndroid Build Coastguard Worker 35, 36,
378*89c4ff92SAndroid Build Coastguard Worker 37, 38,
379*89c4ff92SAndroid Build Coastguard Worker 39, 40,
380*89c4ff92SAndroid Build Coastguard Worker 41, 42,
381*89c4ff92SAndroid Build Coastguard Worker 43, 44,
382*89c4ff92SAndroid Build Coastguard Worker 45, 46,
383*89c4ff92SAndroid Build Coastguard Worker 47, 48
384*89c4ff92SAndroid Build Coastguard Worker };
385*89c4ff92SAndroid Build Coastguard Worker
386*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }};
387*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
388*89c4ff92SAndroid Build Coastguard Worker
389*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
390*89c4ff92SAndroid Build Coastguard Worker }
391*89c4ff92SAndroid Build Coastguard Worker
392*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter4dDim1EndToEnd(const std::vector<BackendId> & backends)393*89c4ff92SAndroid Build Coastguard Worker void Splitter4dDim1EndToEnd(const std::vector<BackendId>& backends)
394*89c4ff92SAndroid Build Coastguard Worker {
395*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
396*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
397*89c4ff92SAndroid Build Coastguard Worker
398*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 1;
399*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
400*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 6, 2, 2 };
401*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
402*89c4ff92SAndroid Build Coastguard Worker
403*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
404*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
405*89c4ff92SAndroid Build Coastguard Worker
406*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
407*89c4ff92SAndroid Build Coastguard Worker
408*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
409*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
410*89c4ff92SAndroid Build Coastguard Worker 1, 2,
411*89c4ff92SAndroid Build Coastguard Worker 3, 4,
412*89c4ff92SAndroid Build Coastguard Worker 5, 6,
413*89c4ff92SAndroid Build Coastguard Worker 7, 8,
414*89c4ff92SAndroid Build Coastguard Worker 9, 10,
415*89c4ff92SAndroid Build Coastguard Worker 11, 12,
416*89c4ff92SAndroid Build Coastguard Worker 13, 14,
417*89c4ff92SAndroid Build Coastguard Worker 15, 16,
418*89c4ff92SAndroid Build Coastguard Worker 17, 18,
419*89c4ff92SAndroid Build Coastguard Worker 19, 20,
420*89c4ff92SAndroid Build Coastguard Worker 21, 22,
421*89c4ff92SAndroid Build Coastguard Worker 23, 24,
422*89c4ff92SAndroid Build Coastguard Worker 25, 26,
423*89c4ff92SAndroid Build Coastguard Worker 27, 28,
424*89c4ff92SAndroid Build Coastguard Worker 29, 30,
425*89c4ff92SAndroid Build Coastguard Worker 31, 32,
426*89c4ff92SAndroid Build Coastguard Worker 33, 34,
427*89c4ff92SAndroid Build Coastguard Worker 35, 36,
428*89c4ff92SAndroid Build Coastguard Worker 37, 38,
429*89c4ff92SAndroid Build Coastguard Worker 39, 40,
430*89c4ff92SAndroid Build Coastguard Worker 41, 42,
431*89c4ff92SAndroid Build Coastguard Worker 43, 44,
432*89c4ff92SAndroid Build Coastguard Worker 45, 46,
433*89c4ff92SAndroid Build Coastguard Worker 47, 48
434*89c4ff92SAndroid Build Coastguard Worker };
435*89c4ff92SAndroid Build Coastguard Worker
436*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
437*89c4ff92SAndroid Build Coastguard Worker 1, 2,
438*89c4ff92SAndroid Build Coastguard Worker 3, 4,
439*89c4ff92SAndroid Build Coastguard Worker 5, 6,
440*89c4ff92SAndroid Build Coastguard Worker 7, 8,
441*89c4ff92SAndroid Build Coastguard Worker 9, 10,
442*89c4ff92SAndroid Build Coastguard Worker 11, 12,
443*89c4ff92SAndroid Build Coastguard Worker 25, 26,
444*89c4ff92SAndroid Build Coastguard Worker 27, 28,
445*89c4ff92SAndroid Build Coastguard Worker 29, 30,
446*89c4ff92SAndroid Build Coastguard Worker 31, 32,
447*89c4ff92SAndroid Build Coastguard Worker 33, 34,
448*89c4ff92SAndroid Build Coastguard Worker 35, 36
449*89c4ff92SAndroid Build Coastguard Worker };
450*89c4ff92SAndroid Build Coastguard Worker
451*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
452*89c4ff92SAndroid Build Coastguard Worker 13, 14,
453*89c4ff92SAndroid Build Coastguard Worker 15, 16,
454*89c4ff92SAndroid Build Coastguard Worker 17, 18,
455*89c4ff92SAndroid Build Coastguard Worker 19, 20,
456*89c4ff92SAndroid Build Coastguard Worker 21, 22,
457*89c4ff92SAndroid Build Coastguard Worker 23, 24,
458*89c4ff92SAndroid Build Coastguard Worker 37, 38,
459*89c4ff92SAndroid Build Coastguard Worker 39, 40,
460*89c4ff92SAndroid Build Coastguard Worker 41, 42,
461*89c4ff92SAndroid Build Coastguard Worker 43, 44,
462*89c4ff92SAndroid Build Coastguard Worker 45, 46,
463*89c4ff92SAndroid Build Coastguard Worker 47, 48
464*89c4ff92SAndroid Build Coastguard Worker };
465*89c4ff92SAndroid Build Coastguard Worker
466*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }};
467*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
468*89c4ff92SAndroid Build Coastguard Worker
469*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
470*89c4ff92SAndroid Build Coastguard Worker }
471*89c4ff92SAndroid Build Coastguard Worker
472*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType>
Splitter4dDim2EndToEnd(const std::vector<BackendId> & backends)473*89c4ff92SAndroid Build Coastguard Worker void Splitter4dDim2EndToEnd(const std::vector<BackendId>& backends)
474*89c4ff92SAndroid Build Coastguard Worker {
475*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
476*89c4ff92SAndroid Build Coastguard Worker using T = ResolveType<ArmnnType>;
477*89c4ff92SAndroid Build Coastguard Worker
478*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 2;
479*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
480*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 3, 4, 2 };
481*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
482*89c4ff92SAndroid Build Coastguard Worker
483*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
484*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
485*89c4ff92SAndroid Build Coastguard Worker
486*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
487*89c4ff92SAndroid Build Coastguard Worker
488*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
489*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
490*89c4ff92SAndroid Build Coastguard Worker 1, 2,
491*89c4ff92SAndroid Build Coastguard Worker 3, 4,
492*89c4ff92SAndroid Build Coastguard Worker 5, 6,
493*89c4ff92SAndroid Build Coastguard Worker 7, 8,
494*89c4ff92SAndroid Build Coastguard Worker 9, 10,
495*89c4ff92SAndroid Build Coastguard Worker 11, 12,
496*89c4ff92SAndroid Build Coastguard Worker 13, 14,
497*89c4ff92SAndroid Build Coastguard Worker 15, 16,
498*89c4ff92SAndroid Build Coastguard Worker 17, 18,
499*89c4ff92SAndroid Build Coastguard Worker 19, 20,
500*89c4ff92SAndroid Build Coastguard Worker 21, 22,
501*89c4ff92SAndroid Build Coastguard Worker 23, 24,
502*89c4ff92SAndroid Build Coastguard Worker 25, 26,
503*89c4ff92SAndroid Build Coastguard Worker 27, 28,
504*89c4ff92SAndroid Build Coastguard Worker 29, 30,
505*89c4ff92SAndroid Build Coastguard Worker 31, 32,
506*89c4ff92SAndroid Build Coastguard Worker 33, 34,
507*89c4ff92SAndroid Build Coastguard Worker 35, 36,
508*89c4ff92SAndroid Build Coastguard Worker 37, 38,
509*89c4ff92SAndroid Build Coastguard Worker 39, 40,
510*89c4ff92SAndroid Build Coastguard Worker 41, 42,
511*89c4ff92SAndroid Build Coastguard Worker 43, 44,
512*89c4ff92SAndroid Build Coastguard Worker 45, 46,
513*89c4ff92SAndroid Build Coastguard Worker 47, 48
514*89c4ff92SAndroid Build Coastguard Worker };
515*89c4ff92SAndroid Build Coastguard Worker
516*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
517*89c4ff92SAndroid Build Coastguard Worker 1, 2,
518*89c4ff92SAndroid Build Coastguard Worker 3, 4,
519*89c4ff92SAndroid Build Coastguard Worker 9, 10,
520*89c4ff92SAndroid Build Coastguard Worker 11, 12,
521*89c4ff92SAndroid Build Coastguard Worker 17, 18,
522*89c4ff92SAndroid Build Coastguard Worker 19, 20,
523*89c4ff92SAndroid Build Coastguard Worker 25, 26,
524*89c4ff92SAndroid Build Coastguard Worker 27, 28,
525*89c4ff92SAndroid Build Coastguard Worker 33, 34,
526*89c4ff92SAndroid Build Coastguard Worker 35, 36,
527*89c4ff92SAndroid Build Coastguard Worker 41, 42,
528*89c4ff92SAndroid Build Coastguard Worker 43, 44
529*89c4ff92SAndroid Build Coastguard Worker };
530*89c4ff92SAndroid Build Coastguard Worker
531*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
532*89c4ff92SAndroid Build Coastguard Worker 5, 6,
533*89c4ff92SAndroid Build Coastguard Worker 7, 8,
534*89c4ff92SAndroid Build Coastguard Worker 13, 14,
535*89c4ff92SAndroid Build Coastguard Worker 15, 16,
536*89c4ff92SAndroid Build Coastguard Worker 21, 22,
537*89c4ff92SAndroid Build Coastguard Worker 23, 24,
538*89c4ff92SAndroid Build Coastguard Worker 29, 30,
539*89c4ff92SAndroid Build Coastguard Worker 31, 32,
540*89c4ff92SAndroid Build Coastguard Worker 37, 38,
541*89c4ff92SAndroid Build Coastguard Worker 39, 40,
542*89c4ff92SAndroid Build Coastguard Worker 45, 46,
543*89c4ff92SAndroid Build Coastguard Worker 47, 48
544*89c4ff92SAndroid Build Coastguard Worker };
545*89c4ff92SAndroid Build Coastguard Worker
546*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }};
547*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
548*89c4ff92SAndroid Build Coastguard Worker
549*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
550*89c4ff92SAndroid Build Coastguard Worker }
551*89c4ff92SAndroid Build Coastguard Worker
552*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
Splitter4dDim3EndToEnd(const std::vector<BackendId> & backends)553*89c4ff92SAndroid Build Coastguard Worker void Splitter4dDim3EndToEnd(const std::vector<BackendId>& backends)
554*89c4ff92SAndroid Build Coastguard Worker {
555*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
556*89c4ff92SAndroid Build Coastguard Worker
557*89c4ff92SAndroid Build Coastguard Worker unsigned int splitAxis = 3;
558*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = 2;
559*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 3, 4, 2 };
560*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorShape> outputShapes{{ 2, 3, 4, 1 }, { 2, 3, 4, 1 }};
561*89c4ff92SAndroid Build Coastguard Worker
562*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
563*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateSplitterNetwork<ArmnnType>(inputShape, outputShapes, splitAxis, numSplit);
564*89c4ff92SAndroid Build Coastguard Worker
565*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
566*89c4ff92SAndroid Build Coastguard Worker
567*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
568*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
569*89c4ff92SAndroid Build Coastguard Worker 1, 2,
570*89c4ff92SAndroid Build Coastguard Worker 3, 4,
571*89c4ff92SAndroid Build Coastguard Worker 5, 6,
572*89c4ff92SAndroid Build Coastguard Worker 7, 8,
573*89c4ff92SAndroid Build Coastguard Worker 9, 10,
574*89c4ff92SAndroid Build Coastguard Worker 11, 12,
575*89c4ff92SAndroid Build Coastguard Worker 13, 14,
576*89c4ff92SAndroid Build Coastguard Worker 15, 16,
577*89c4ff92SAndroid Build Coastguard Worker 17, 18,
578*89c4ff92SAndroid Build Coastguard Worker 19, 20,
579*89c4ff92SAndroid Build Coastguard Worker 21, 22,
580*89c4ff92SAndroid Build Coastguard Worker 23, 24,
581*89c4ff92SAndroid Build Coastguard Worker 25, 26,
582*89c4ff92SAndroid Build Coastguard Worker 27, 28,
583*89c4ff92SAndroid Build Coastguard Worker 29, 30,
584*89c4ff92SAndroid Build Coastguard Worker 31, 32,
585*89c4ff92SAndroid Build Coastguard Worker 33, 34,
586*89c4ff92SAndroid Build Coastguard Worker 35, 36,
587*89c4ff92SAndroid Build Coastguard Worker 37, 38,
588*89c4ff92SAndroid Build Coastguard Worker 39, 40,
589*89c4ff92SAndroid Build Coastguard Worker 41, 42,
590*89c4ff92SAndroid Build Coastguard Worker 43, 44,
591*89c4ff92SAndroid Build Coastguard Worker 45, 46,
592*89c4ff92SAndroid Build Coastguard Worker 47, 48
593*89c4ff92SAndroid Build Coastguard Worker };
594*89c4ff92SAndroid Build Coastguard Worker
595*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput0{
596*89c4ff92SAndroid Build Coastguard Worker 1, 3, 5, 7,
597*89c4ff92SAndroid Build Coastguard Worker 9, 11, 13, 15,
598*89c4ff92SAndroid Build Coastguard Worker 17, 19, 21, 23,
599*89c4ff92SAndroid Build Coastguard Worker 25, 27, 29, 31,
600*89c4ff92SAndroid Build Coastguard Worker 33, 35, 37, 39,
601*89c4ff92SAndroid Build Coastguard Worker 41, 43, 45, 47
602*89c4ff92SAndroid Build Coastguard Worker };
603*89c4ff92SAndroid Build Coastguard Worker
604*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput1{
605*89c4ff92SAndroid Build Coastguard Worker 2, 4, 6, 8,
606*89c4ff92SAndroid Build Coastguard Worker 10, 12, 14, 16,
607*89c4ff92SAndroid Build Coastguard Worker 18, 20, 22, 24,
608*89c4ff92SAndroid Build Coastguard Worker 26, 28, 30, 32,
609*89c4ff92SAndroid Build Coastguard Worker 34, 36, 38, 40,
610*89c4ff92SAndroid Build Coastguard Worker 42, 44, 46, 48
611*89c4ff92SAndroid Build Coastguard Worker };
612*89c4ff92SAndroid Build Coastguard Worker
613*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0,inputData }};
614*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput0 }, { 1, expectedOutput1 }};
615*89c4ff92SAndroid Build Coastguard Worker
616*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
617*89c4ff92SAndroid Build Coastguard Worker }
618*89c4ff92SAndroid Build Coastguard Worker
619*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
620