xref: /aosp_15_r20/external/armnn/delegate/test/ControlTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ControlTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
// CONCATENATION operator test scenarios
ConcatUint8TwoInputsTest(std::vector<armnn::BackendId> & backends)19*89c4ff92SAndroid Build Coastguard Worker void ConcatUint8TwoInputsTest(std::vector<armnn::BackendId>& backends)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2, 2 };
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 4, 2 };
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     // Set input and output data
25*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> inputValues;
26*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue1 { 0, 1, 2, 3 }; // Lower bounds
27*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue2 { 252, 253, 254, 255 }; // Upper bounds
28*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue1);
29*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue2);
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> expectedOutputValues { 0, 1, 2, 3, 252, 253, 254, 255 };
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     ConcatenationTest<uint8_t>(tflite::BuiltinOperator_CONCATENATION,
34*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_UINT8,
35*89c4ff92SAndroid Build Coastguard Worker                                backends,
36*89c4ff92SAndroid Build Coastguard Worker                                inputShape,
37*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputShape,
38*89c4ff92SAndroid Build Coastguard Worker                                inputValues,
39*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues);
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
ConcatInt16TwoInputsTest(std::vector<armnn::BackendId> & backends)42*89c4ff92SAndroid Build Coastguard Worker void ConcatInt16TwoInputsTest(std::vector<armnn::BackendId>& backends)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2, 2 };
45*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 4, 2 };
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<int16_t>> inputValues;
48*89c4ff92SAndroid Build Coastguard Worker     std::vector<int16_t> inputValue1 { -32768, -16384, -1, 0 };
49*89c4ff92SAndroid Build Coastguard Worker     std::vector<int16_t> inputValue2 { 1, 2, 16384, 32767 };
50*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue1);
51*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue2);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     std::vector<int16_t> expectedOutputValues { -32768, -16384, -1, 0, 1, 2, 16384, 32767};
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     ConcatenationTest<int16_t>(tflite::BuiltinOperator_CONCATENATION,
56*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT16,
57*89c4ff92SAndroid Build Coastguard Worker                                backends,
58*89c4ff92SAndroid Build Coastguard Worker                                inputShape,
59*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputShape,
60*89c4ff92SAndroid Build Coastguard Worker                                inputValues,
61*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues);
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
ConcatFloat32TwoInputsTest(std::vector<armnn::BackendId> & backends)64*89c4ff92SAndroid Build Coastguard Worker void ConcatFloat32TwoInputsTest(std::vector<armnn::BackendId>& backends)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2, 2 };
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 4, 2 };
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> inputValues;
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue1 { -127.f, -126.f, -1.f, 0.f };
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValue2 { 1.f, 2.f, 126.f, 127.f };
72*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue1);
73*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue2);
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues { -127.f, -126.f, -1.f, 0.f, 1.f, 2.f, 126.f, 127.f };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     ConcatenationTest<float>(tflite::BuiltinOperator_CONCATENATION,
78*89c4ff92SAndroid Build Coastguard Worker                              ::tflite::TensorType_FLOAT32,
79*89c4ff92SAndroid Build Coastguard Worker                              backends,
80*89c4ff92SAndroid Build Coastguard Worker                              inputShape,
81*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputShape,
82*89c4ff92SAndroid Build Coastguard Worker                              inputValues,
83*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputValues);
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker 
ConcatThreeInputsTest(std::vector<armnn::BackendId> & backends)86*89c4ff92SAndroid Build Coastguard Worker void ConcatThreeInputsTest(std::vector<armnn::BackendId>& backends)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2, 2 };
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 6, 2 };
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> inputValues;
92*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue1 { 0, 1, 2, 3 };
93*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue2 { 125, 126, 127, 128 };
94*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue3 { 252, 253, 254, 255 };
95*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue1);
96*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue2);
97*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue3);
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> expectedOutputValues { 0, 1, 2, 3, 125, 126, 127, 128, 252, 253, 254, 255 };
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     ConcatenationTest<uint8_t>(tflite::BuiltinOperator_CONCATENATION,
102*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_UINT8,
103*89c4ff92SAndroid Build Coastguard Worker                                backends,
104*89c4ff92SAndroid Build Coastguard Worker                                inputShape,
105*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputShape,
106*89c4ff92SAndroid Build Coastguard Worker                                inputValues,
107*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues);
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker 
ConcatAxisTest(std::vector<armnn::BackendId> & backends)110*89c4ff92SAndroid Build Coastguard Worker void ConcatAxisTest(std::vector<armnn::BackendId>& backends)
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 1, 2, 2 };
113*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 1, 2, 4 };
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<uint8_t>> inputValues;
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue1 { 0, 1, 2, 3 };
117*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputValue3 { 252, 253, 254, 255 };
118*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue1);
119*89c4ff92SAndroid Build Coastguard Worker     inputValues.push_back(inputValue3);
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> expectedOutputValues { 0, 1, 252, 253, 2, 3, 254, 255 };
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     ConcatenationTest<uint8_t>(tflite::BuiltinOperator_CONCATENATION,
124*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_UINT8,
125*89c4ff92SAndroid Build Coastguard Worker                                backends,
126*89c4ff92SAndroid Build Coastguard Worker                                inputShape,
127*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputShape,
128*89c4ff92SAndroid Build Coastguard Worker                                inputValues,
129*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
130*89c4ff92SAndroid Build Coastguard Worker                                2);
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
// MEAN operator test scenarios
MeanUint8KeepDimsTest(std::vector<armnn::BackendId> & backends)134*89c4ff92SAndroid Build Coastguard Worker void MeanUint8KeepDimsTest(std::vector<armnn::BackendId>& backends)
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input0Shape { 1, 3 };
137*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Shape { 1 };
138*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 1, 1 };
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> input0Values { 5, 10, 15 }; // Inputs
141*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Values { 1 }; // Axis
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> expectedOutputValues { 10 };
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     MeanTest<uint8_t>(tflite::BuiltinOperator_MEAN,
146*89c4ff92SAndroid Build Coastguard Worker                       ::tflite::TensorType_UINT8,
147*89c4ff92SAndroid Build Coastguard Worker                       backends,
148*89c4ff92SAndroid Build Coastguard Worker                       input0Shape,
149*89c4ff92SAndroid Build Coastguard Worker                       input1Shape,
150*89c4ff92SAndroid Build Coastguard Worker                       expectedOutputShape,
151*89c4ff92SAndroid Build Coastguard Worker                       input0Values,
152*89c4ff92SAndroid Build Coastguard Worker                       input1Values,
153*89c4ff92SAndroid Build Coastguard Worker                       expectedOutputValues,
154*89c4ff92SAndroid Build Coastguard Worker                       true);
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker 
MeanUint8Test(std::vector<armnn::BackendId> & backends)157*89c4ff92SAndroid Build Coastguard Worker void MeanUint8Test(std::vector<armnn::BackendId>& backends)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input0Shape { 1, 2, 2 };
160*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Shape { 1 };
161*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 2, 2 };
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> input0Values { 5, 10, 15, 20 }; // Inputs
164*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Values { 0 }; // Axis
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> expectedOutputValues { 5, 10, 15, 20 };
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker     MeanTest<uint8_t>(tflite::BuiltinOperator_MEAN,
169*89c4ff92SAndroid Build Coastguard Worker                       ::tflite::TensorType_UINT8,
170*89c4ff92SAndroid Build Coastguard Worker                       backends,
171*89c4ff92SAndroid Build Coastguard Worker                       input0Shape,
172*89c4ff92SAndroid Build Coastguard Worker                       input1Shape,
173*89c4ff92SAndroid Build Coastguard Worker                       expectedOutputShape,
174*89c4ff92SAndroid Build Coastguard Worker                       input0Values,
175*89c4ff92SAndroid Build Coastguard Worker                       input1Values,
176*89c4ff92SAndroid Build Coastguard Worker                       expectedOutputValues,
177*89c4ff92SAndroid Build Coastguard Worker                       false);
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker 
MeanFp32KeepDimsTest(std::vector<armnn::BackendId> & backends)180*89c4ff92SAndroid Build Coastguard Worker void MeanFp32KeepDimsTest(std::vector<armnn::BackendId>& backends)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input0Shape { 1, 2, 2 };
183*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Shape { 1 };
184*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 1, 1, 2 };
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   input0Values { 1.0f, 1.5f, 2.0f, 2.5f }; // Inputs
187*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Values { 1 }; // Axis
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   expectedOutputValues { 1.5f, 2.0f };
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     MeanTest<float>(tflite::BuiltinOperator_MEAN,
192*89c4ff92SAndroid Build Coastguard Worker                     ::tflite::TensorType_FLOAT32,
193*89c4ff92SAndroid Build Coastguard Worker                     backends,
194*89c4ff92SAndroid Build Coastguard Worker                     input0Shape,
195*89c4ff92SAndroid Build Coastguard Worker                     input1Shape,
196*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputShape,
197*89c4ff92SAndroid Build Coastguard Worker                     input0Values,
198*89c4ff92SAndroid Build Coastguard Worker                     input1Values,
199*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputValues,
200*89c4ff92SAndroid Build Coastguard Worker                     true);
201*89c4ff92SAndroid Build Coastguard Worker }
202*89c4ff92SAndroid Build Coastguard Worker 
MeanFp32Test(std::vector<armnn::BackendId> & backends)203*89c4ff92SAndroid Build Coastguard Worker void MeanFp32Test(std::vector<armnn::BackendId>& backends)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input0Shape { 1, 2, 2, 1 };
206*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Shape { 1 };
207*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputShape { 1, 2, 1 };
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   input0Values { 1.0f, 1.5f, 2.0f, 2.5f }; // Inputs
210*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> input1Values { 2 }; // Axis
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     std::vector<float>   expectedOutputValues { 1.25f, 2.25f };
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker     MeanTest<float>(tflite::BuiltinOperator_MEAN,
215*89c4ff92SAndroid Build Coastguard Worker                     ::tflite::TensorType_FLOAT32,
216*89c4ff92SAndroid Build Coastguard Worker                     backends,
217*89c4ff92SAndroid Build Coastguard Worker                     input0Shape,
218*89c4ff92SAndroid Build Coastguard Worker                     input1Shape,
219*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputShape,
220*89c4ff92SAndroid Build Coastguard Worker                     input0Values,
221*89c4ff92SAndroid Build Coastguard Worker                     input1Values,
222*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputValues,
223*89c4ff92SAndroid Build Coastguard Worker                     false);
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker 
// CONCATENATION tests, one suite per backend.
227*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Concatenation_CpuAccTests")
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Uint8_CpuAcc_Test")
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
233*89c4ff92SAndroid Build Coastguard Worker     ConcatUint8TwoInputsTest(backends);
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker 
236*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Int16_CpuAcc_Test")
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
239*89c4ff92SAndroid Build Coastguard Worker     ConcatInt16TwoInputsTest(backends);
240*89c4ff92SAndroid Build Coastguard Worker }
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Float32_CpuAcc_Test")
243*89c4ff92SAndroid Build Coastguard Worker {
244*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
245*89c4ff92SAndroid Build Coastguard Worker     ConcatFloat32TwoInputsTest(backends);
246*89c4ff92SAndroid Build Coastguard Worker }
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Three_Inputs_CpuAcc_Test")
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
251*89c4ff92SAndroid Build Coastguard Worker     ConcatThreeInputsTest(backends);
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Axis_CpuAcc_Test")
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
257*89c4ff92SAndroid Build Coastguard Worker     ConcatAxisTest(backends);
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker }
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Concatenation_GpuAccTests")
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker 
265*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Uint8_GpuAcc_Test")
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
268*89c4ff92SAndroid Build Coastguard Worker     ConcatUint8TwoInputsTest(backends);
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker 
271*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Int16_GpuAcc_Test")
272*89c4ff92SAndroid Build Coastguard Worker {
273*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
274*89c4ff92SAndroid Build Coastguard Worker     ConcatInt16TwoInputsTest(backends);
275*89c4ff92SAndroid Build Coastguard Worker }
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Float32_GpuAcc_Test")
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
280*89c4ff92SAndroid Build Coastguard Worker     ConcatFloat32TwoInputsTest(backends);
281*89c4ff92SAndroid Build Coastguard Worker }
282*89c4ff92SAndroid Build Coastguard Worker 
283*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Three_Inputs_GpuAcc_Test")
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
286*89c4ff92SAndroid Build Coastguard Worker     ConcatThreeInputsTest(backends);
287*89c4ff92SAndroid Build Coastguard Worker }
288*89c4ff92SAndroid Build Coastguard Worker 
289*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Axis_GpuAcc_Test")
290*89c4ff92SAndroid Build Coastguard Worker {
291*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
292*89c4ff92SAndroid Build Coastguard Worker     ConcatAxisTest(backends);
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker }
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Concatenation_CpuRefTests")
298*89c4ff92SAndroid Build Coastguard Worker {
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Uint8_CpuRef_Test")
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
303*89c4ff92SAndroid Build Coastguard Worker     ConcatUint8TwoInputsTest(backends);
304*89c4ff92SAndroid Build Coastguard Worker }
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Int16_CpuRef_Test")
307*89c4ff92SAndroid Build Coastguard Worker {
308*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
309*89c4ff92SAndroid Build Coastguard Worker     ConcatInt16TwoInputsTest(backends);
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Float32_CpuRef_Test")
313*89c4ff92SAndroid Build Coastguard Worker {
314*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
315*89c4ff92SAndroid Build Coastguard Worker     ConcatFloat32TwoInputsTest(backends);
316*89c4ff92SAndroid Build Coastguard Worker }
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Three_Inputs_CpuRef_Test")
319*89c4ff92SAndroid Build Coastguard Worker {
320*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
321*89c4ff92SAndroid Build Coastguard Worker     ConcatThreeInputsTest(backends);
322*89c4ff92SAndroid Build Coastguard Worker }
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Concatenation_Axis_CpuRef_Test")
325*89c4ff92SAndroid Build Coastguard Worker {
326*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
327*89c4ff92SAndroid Build Coastguard Worker     ConcatAxisTest(backends);
328*89c4ff92SAndroid Build Coastguard Worker }
329*89c4ff92SAndroid Build Coastguard Worker 
330*89c4ff92SAndroid Build Coastguard Worker }
331*89c4ff92SAndroid Build Coastguard Worker 
// MEAN tests, one suite per backend.
333*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Mean_CpuAccTests")
334*89c4ff92SAndroid Build Coastguard Worker {
335*89c4ff92SAndroid Build Coastguard Worker 
336*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_KeepDims_CpuAcc_Test")
337*89c4ff92SAndroid Build Coastguard Worker {
338*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
339*89c4ff92SAndroid Build Coastguard Worker     MeanUint8KeepDimsTest(backends);
340*89c4ff92SAndroid Build Coastguard Worker }
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_CpuAcc_Test")
343*89c4ff92SAndroid Build Coastguard Worker {
344*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
345*89c4ff92SAndroid Build Coastguard Worker     MeanUint8Test(backends);
346*89c4ff92SAndroid Build Coastguard Worker }
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_KeepDims_CpuAcc_Test")
349*89c4ff92SAndroid Build Coastguard Worker {
350*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
351*89c4ff92SAndroid Build Coastguard Worker     MeanFp32KeepDimsTest(backends);
352*89c4ff92SAndroid Build Coastguard Worker }
353*89c4ff92SAndroid Build Coastguard Worker 
354*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_CpuAcc_Test")
355*89c4ff92SAndroid Build Coastguard Worker {
356*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
357*89c4ff92SAndroid Build Coastguard Worker     MeanFp32Test(backends);
358*89c4ff92SAndroid Build Coastguard Worker }
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker 
362*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Mean_GpuAccTests")
363*89c4ff92SAndroid Build Coastguard Worker {
364*89c4ff92SAndroid Build Coastguard Worker 
365*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_KeepDims_GpuAcc_Test")
366*89c4ff92SAndroid Build Coastguard Worker {
367*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
368*89c4ff92SAndroid Build Coastguard Worker     MeanUint8KeepDimsTest(backends);
369*89c4ff92SAndroid Build Coastguard Worker }
370*89c4ff92SAndroid Build Coastguard Worker 
371*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_GpuAcc_Test")
372*89c4ff92SAndroid Build Coastguard Worker {
373*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
374*89c4ff92SAndroid Build Coastguard Worker     MeanUint8Test(backends);
375*89c4ff92SAndroid Build Coastguard Worker }
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_KeepDims_GpuAcc_Test")
378*89c4ff92SAndroid Build Coastguard Worker {
379*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
380*89c4ff92SAndroid Build Coastguard Worker     MeanFp32KeepDimsTest(backends);
381*89c4ff92SAndroid Build Coastguard Worker }
382*89c4ff92SAndroid Build Coastguard Worker 
383*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_GpuAcc_Test")
384*89c4ff92SAndroid Build Coastguard Worker {
385*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
386*89c4ff92SAndroid Build Coastguard Worker     MeanFp32Test(backends);
387*89c4ff92SAndroid Build Coastguard Worker }
388*89c4ff92SAndroid Build Coastguard Worker 
389*89c4ff92SAndroid Build Coastguard Worker }
390*89c4ff92SAndroid Build Coastguard Worker 
391*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Mean_CpuRefTests")
392*89c4ff92SAndroid Build Coastguard Worker {
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_KeepDims_CpuRef_Test")
395*89c4ff92SAndroid Build Coastguard Worker {
396*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
397*89c4ff92SAndroid Build Coastguard Worker     MeanUint8KeepDimsTest(backends);
398*89c4ff92SAndroid Build Coastguard Worker }
399*89c4ff92SAndroid Build Coastguard Worker 
400*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Uint8_CpuRef_Test")
401*89c4ff92SAndroid Build Coastguard Worker {
402*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
403*89c4ff92SAndroid Build Coastguard Worker     MeanUint8Test(backends);
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker 
406*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_KeepDims_CpuRef_Test")
407*89c4ff92SAndroid Build Coastguard Worker {
408*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
409*89c4ff92SAndroid Build Coastguard Worker     MeanFp32KeepDimsTest(backends);
410*89c4ff92SAndroid Build Coastguard Worker }
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Mean_Fp32_CpuRef_Test")
413*89c4ff92SAndroid Build Coastguard Worker {
414*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
415*89c4ff92SAndroid Build Coastguard Worker     MeanFp32Test(backends);
416*89c4ff92SAndroid Build Coastguard Worker }
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker }
419*89c4ff92SAndroid Build Coastguard Worker 
420*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate