xref: /aosp_15_r20/external/armnn/src/armnnUtils/test/TensorUtilsTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
13*89c4ff92SAndroid Build Coastguard Worker using namespace armnnUtils;
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorUtilsSuite")
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis0Test")
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension 0
22*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, 0);
23*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
24*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
25*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 2);
26*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
27*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis1Test")
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension 1
35*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, 1);
36*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
37*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
38*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
39*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
40*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis2Test")
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension 2
48*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, 2);
49*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
50*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
51*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 3);
52*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 1);
53*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis3Test")
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension 3
61*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, 3);
62*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
63*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
64*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 3);
65*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 4);
66*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 1);
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis1Test")
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension -1
74*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, -1);
75*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
76*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
77*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 3);
78*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 4);
79*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 1);
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis2Test")
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension -2
87*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, -2);
88*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
89*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
90*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 3);
91*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 1);
92*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis3Test")
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension -3
100*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, -3);
101*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
102*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
103*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
104*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
105*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis4Test")
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     // Expand dimension -4
113*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDims(inputShape, -4);
114*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
115*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
116*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 2);
117*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
118*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidAxisTest")
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker     // Invalid expand dimension 4
126*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     // Invalid expand dimension -5
134*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy1Rank")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     // Expand by 1 dimension
142*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
143*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
144*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
145*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 2);
146*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
147*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy2Ranks")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 3, 4 });
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     // Expand 2 dimensions
155*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
156*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
157*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
158*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
159*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 3);
160*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy3Ranks")
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 4 });
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     // Expand 3 dimensions
168*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
169*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 4);
170*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
171*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
172*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 1);
173*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[3] == 4);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidRankAmount")
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker     // Don't expand because target rank is smaller than current rank
181*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2);
182*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 3);
183*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
184*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 3);
185*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 4);
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsToRankInvalidTensorShape")
189*89c4ff92SAndroid Build Coastguard Worker {
190*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 2, 3, 4 });
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker     // Throw exception because rank 6 tensors are unsupported by armnn
193*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException);
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeAll1s")
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 1, 1, 1 });
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker     // Reduce dimension 2
202*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ReduceDims(inputShape, 2);
203*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 2);
204*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
205*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeNotEnough1s")
209*89c4ff92SAndroid Build Coastguard Worker {
210*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 1, 2, 1 });
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     // Reduce dimension 1
213*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ReduceDims(inputShape, 1);
214*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 2);
215*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 2);
216*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsInfoAll1s")
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32);
222*89c4ff92SAndroid Build Coastguard Worker 
223*89c4ff92SAndroid Build Coastguard Worker     // Reduce dimension 2
224*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2);
225*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetShape().GetNumDimensions() == 2);
226*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetShape()[0] == 1);
227*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetShape()[1] == 1);
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsInfoNotEnough1s")
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32);
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker     // Reduce dimension 1
235*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1);
236*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetNumDimensions() == 2);
237*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetShape()[0] == 2);
238*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputInfo.GetShape()[1] == 1);
239*89c4ff92SAndroid Build Coastguard Worker }
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape inputShape({ 1, 1, 1 });
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker     // Do not reduce because dimension does not exist
246*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape outputShape = ReduceDims(inputShape, 4);
247*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape.GetNumDimensions() == 3);
248*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[0] == 1);
249*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[1] == 1);
250*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputShape[2] == 1);
251*89c4ff92SAndroid Build Coastguard Worker }
252*89c4ff92SAndroid Build Coastguard Worker 
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayInvalidDataType")
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 2, 3, 4 }, armnn::DataType::BFloat16);
257*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data {1,2,3,4,5,6,7,8,9,10};
258*89c4ff92SAndroid Build Coastguard Worker 
259*89c4ff92SAndroid Build Coastguard Worker     // Invalid argument
260*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(ToFloatArray(data, info), armnn::InvalidArgumentException);
261*89c4ff92SAndroid Build Coastguard Worker }
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQSymmS8PerAxis")
264*89c4ff92SAndroid Build Coastguard Worker {
265*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
266*89c4ff92SAndroid Build Coastguard Worker     unsigned int quantizationDim = 1;
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, quantizationScales, quantizationDim);
269*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
270*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
275*89c4ff92SAndroid Build Coastguard Worker     {
276*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
277*89c4ff92SAndroid Build Coastguard Worker     }
278*89c4ff92SAndroid Build Coastguard Worker }
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQSymmS8")
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, 0.1f);
283*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
284*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f,  -5.6f, -4.6f, -3.6f };
285*89c4ff92SAndroid Build Coastguard Worker 
286*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
289*89c4ff92SAndroid Build Coastguard Worker     {
290*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
291*89c4ff92SAndroid Build Coastguard Worker     }
292*89c4ff92SAndroid Build Coastguard Worker }
293*89c4ff92SAndroid Build Coastguard Worker 
294*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmS8PerAxis")
295*89c4ff92SAndroid Build Coastguard Worker {
296*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
297*89c4ff92SAndroid Build Coastguard Worker     unsigned int quantizationDim = 1;
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, quantizationScales, quantizationDim);
300*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
301*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
304*89c4ff92SAndroid Build Coastguard Worker 
305*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
306*89c4ff92SAndroid Build Coastguard Worker     {
307*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
308*89c4ff92SAndroid Build Coastguard Worker     }
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmS8")
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, 0.1f);
314*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
315*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f,  -5.6f, -4.6f, -3.6f };
316*89c4ff92SAndroid Build Coastguard Worker 
317*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
320*89c4ff92SAndroid Build Coastguard Worker     {
321*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
322*89c4ff92SAndroid Build Coastguard Worker     }
323*89c4ff92SAndroid Build Coastguard Worker }
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQASymmU8PerAxis")
326*89c4ff92SAndroid Build Coastguard Worker {
327*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
328*89c4ff92SAndroid Build Coastguard Worker     unsigned int quantizationDim = 1;
329*89c4ff92SAndroid Build Coastguard Worker 
330*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, quantizationScales, quantizationDim);
331*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
332*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
333*89c4ff92SAndroid Build Coastguard Worker 
334*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
335*89c4ff92SAndroid Build Coastguard Worker 
336*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
337*89c4ff92SAndroid Build Coastguard Worker     {
338*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
339*89c4ff92SAndroid Build Coastguard Worker     }
340*89c4ff92SAndroid Build Coastguard Worker }
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmU8")
343*89c4ff92SAndroid Build Coastguard Worker {
344*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, 0.1f);
345*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
346*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
351*89c4ff92SAndroid Build Coastguard Worker     {
352*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
353*89c4ff92SAndroid Build Coastguard Worker     }
354*89c4ff92SAndroid Build Coastguard Worker }
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned32PerAxis")
357*89c4ff92SAndroid Build Coastguard Worker {
358*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
359*89c4ff92SAndroid Build Coastguard Worker     unsigned int quantizationDim = 1;
360*89c4ff92SAndroid Build Coastguard Worker 
361*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, quantizationScales, quantizationDim);
362*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
363*89c4ff92SAndroid Build Coastguard Worker                                 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
364*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
365*89c4ff92SAndroid Build Coastguard Worker 
366*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
369*89c4ff92SAndroid Build Coastguard Worker     {
370*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
371*89c4ff92SAndroid Build Coastguard Worker     }
372*89c4ff92SAndroid Build Coastguard Worker }
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned32")
375*89c4ff92SAndroid Build Coastguard Worker {
376*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, 0.1f);
377*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
378*89c4ff92SAndroid Build Coastguard Worker                                 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
379*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
380*89c4ff92SAndroid Build Coastguard Worker 
381*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
382*89c4ff92SAndroid Build Coastguard Worker 
383*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
384*89c4ff92SAndroid Build Coastguard Worker     {
385*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
386*89c4ff92SAndroid Build Coastguard Worker     }
387*89c4ff92SAndroid Build Coastguard Worker }
388*89c4ff92SAndroid Build Coastguard Worker 
389*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned64PerAxis")
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
392*89c4ff92SAndroid Build Coastguard Worker     unsigned int quantizationDim = 1;
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, quantizationScales, quantizationDim);
395*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
396*89c4ff92SAndroid Build Coastguard Worker                                 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
397*89c4ff92SAndroid Build Coastguard Worker                                 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
398*89c4ff92SAndroid Build Coastguard Worker                                 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
399*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
402*89c4ff92SAndroid Build Coastguard Worker 
403*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
404*89c4ff92SAndroid Build Coastguard Worker     {
405*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
406*89c4ff92SAndroid Build Coastguard Worker     }
407*89c4ff92SAndroid Build Coastguard Worker }
408*89c4ff92SAndroid Build Coastguard Worker 
409*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned64")
410*89c4ff92SAndroid Build Coastguard Worker {
411*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, 0.1f);
412*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
413*89c4ff92SAndroid Build Coastguard Worker                                 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
414*89c4ff92SAndroid Build Coastguard Worker                                 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
415*89c4ff92SAndroid Build Coastguard Worker                                 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
416*89c4ff92SAndroid Build Coastguard Worker     float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<float[]> result = ToFloatArray(data, info);
419*89c4ff92SAndroid Build Coastguard Worker 
420*89c4ff92SAndroid Build Coastguard Worker     for (uint i = 0; i < info.GetNumElements(); ++i)
421*89c4ff92SAndroid Build Coastguard Worker     {
422*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(result[i], doctest::Approx(expected[i]));
423*89c4ff92SAndroid Build Coastguard Worker     }
424*89c4ff92SAndroid Build Coastguard Worker }
425*89c4ff92SAndroid Build Coastguard Worker }
426