xref: /aosp_15_r20/external/armnn/src/armnn/test/TensorTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <armnn/Tensor.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <doctest/doctest.h>

#include <cstring>
#include <string>
#include <vector>
10 
11 using namespace armnn;
12 
13 TEST_SUITE("Tensor")
14 {
15 struct TensorInfoFixture
16 {
TensorInfoFixtureTensorInfoFixture17     TensorInfoFixture()
18     {
19         unsigned int sizes[] = {6,7,8,9};
20         m_TensorInfo = TensorInfo(4, sizes, DataType::Float32);
21     }
~TensorInfoFixtureTensorInfoFixture22     ~TensorInfoFixture() {};
23 
24     TensorInfo m_TensorInfo;
25 };
26 
27 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructShapeUsingListInitialization")
28 {
29     TensorShape listInitializedShape{ 6, 7, 8, 9 };
30     CHECK(listInitializedShape == m_TensorInfo.GetShape());
31 }
32 
33 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructTensorInfo")
34 {
35     CHECK(m_TensorInfo.GetNumDimensions() == 4);
36     CHECK(m_TensorInfo.GetShape()[0] == 6); // <= Outer most
37     CHECK(m_TensorInfo.GetShape()[1] == 7);
38     CHECK(m_TensorInfo.GetShape()[2] == 8);
39     CHECK(m_TensorInfo.GetShape()[3] == 9);     // <= Inner most
40 }
41 
42 TEST_CASE_FIXTURE(TensorInfoFixture, "CopyConstructTensorInfo")
43 {
44     TensorInfo copyConstructed(m_TensorInfo);
45     CHECK(copyConstructed.GetNumDimensions() == 4);
46     CHECK(copyConstructed.GetShape()[0] == 6);
47     CHECK(copyConstructed.GetShape()[1] == 7);
48     CHECK(copyConstructed.GetShape()[2] == 8);
49     CHECK(copyConstructed.GetShape()[3] == 9);
50 }
51 
52 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoEquality")
53 {
54     TensorInfo copyConstructed(m_TensorInfo);
55     CHECK(copyConstructed == m_TensorInfo);
56 }
57 
58 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoInequality")
59 {
60     TensorInfo other;
61     unsigned int sizes[] = {2,3,4,5};
62     other = TensorInfo(4, sizes, DataType::Float32);
63 
64     CHECK(other != m_TensorInfo);
65 }
66 
67 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoAssignmentOperator")
68 {
69     TensorInfo copy;
70     copy = m_TensorInfo;
71     CHECK(copy == m_TensorInfo);
72 }
73 
74 TEST_CASE("CopyNoQuantizationTensorInfo")
75 {
76     TensorInfo infoA;
77     infoA.SetShape({ 5, 6, 7, 8 });
78     infoA.SetDataType(DataType::QAsymmU8);
79 
80     TensorInfo infoB;
81     infoB.SetShape({ 5, 6, 7, 8 });
82     infoB.SetDataType(DataType::QAsymmU8);
83     infoB.SetQuantizationScale(10.0f);
84     infoB.SetQuantizationOffset(5);
85     infoB.SetQuantizationDim(Optional<unsigned int>(1));
86 
87     CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
88     CHECK((infoA.GetDataType() == DataType::QAsymmU8));
89     CHECK(infoA.GetQuantizationScale() == 1);
90     CHECK(infoA.GetQuantizationOffset() == 0);
91     CHECK(!infoA.GetQuantizationDim().has_value());
92 
93     CHECK(infoA != infoB);
94     infoA = infoB;
95     CHECK(infoA == infoB);
96 
97     CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
98     CHECK((infoA.GetDataType() == DataType::QAsymmU8));
99     CHECK(infoA.GetQuantizationScale() == 10.0f);
100     CHECK(infoA.GetQuantizationOffset() == 5);
101     CHECK(infoA.GetQuantizationDim().value() == 1);
102 }
103 
104 TEST_CASE("CopyDifferentQuantizationTensorInfo")
105 {
106     TensorInfo infoA;
107     infoA.SetShape({ 5, 6, 7, 8 });
108     infoA.SetDataType(DataType::QAsymmU8);
109     infoA.SetQuantizationScale(10.0f);
110     infoA.SetQuantizationOffset(5);
111     infoA.SetQuantizationDim(Optional<unsigned int>(1));
112 
113     TensorInfo infoB;
114     infoB.SetShape({ 5, 6, 7, 8 });
115     infoB.SetDataType(DataType::QAsymmU8);
116     infoB.SetQuantizationScale(11.0f);
117     infoB.SetQuantizationOffset(6);
118     infoB.SetQuantizationDim(Optional<unsigned int>(2));
119 
120     CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
121     CHECK((infoA.GetDataType() == DataType::QAsymmU8));
122     CHECK(infoA.GetQuantizationScale() == 10.0f);
123     CHECK(infoA.GetQuantizationOffset() == 5);
124     CHECK(infoA.GetQuantizationDim().value() == 1);
125 
126     CHECK(infoA != infoB);
127     infoA = infoB;
128     CHECK(infoA == infoB);
129 
130     CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
131     CHECK((infoA.GetDataType() == DataType::QAsymmU8));
132     CHECK(infoA.GetQuantizationScale() == 11.0f);
133     CHECK(infoA.GetQuantizationOffset() == 6);
134     CHECK(infoA.GetQuantizationDim().value() == 2);
135 }
136 
CheckTensor(const ConstTensor & t)137 void CheckTensor(const ConstTensor& t)
138 {
139     t.GetInfo();
140 }
141 
142 TEST_CASE("TensorVsConstTensor")
143 {
144     int mutableDatum = 2;
145     const int immutableDatum = 3;
146 
147     armnn::Tensor uninitializedTensor;
148     uninitializedTensor.GetInfo().SetConstant(true);
149     armnn::ConstTensor uninitializedTensor2;
150 
151     uninitializedTensor2 = uninitializedTensor;
152 
153     armnn::TensorInfo emptyTensorInfo;
154     emptyTensorInfo.SetConstant(true);
155     armnn::Tensor t(emptyTensorInfo, &mutableDatum);
156     armnn::ConstTensor ct(emptyTensorInfo, &immutableDatum);
157 
158     // Checks that both Tensor and ConstTensor can be passed as a ConstTensor.
159     CheckTensor(t);
160     CheckTensor(ct);
161 }
162 
163 TEST_CASE("ConstTensor_EmptyConstructorTensorInfoSet")
164 {
165     armnn::ConstTensor t;
166     CHECK(t.GetInfo().IsConstant() == true);
167 }
168 
169 TEST_CASE("ConstTensor_TensorInfoNotConstantError")
170 {
171     armnn::TensorInfo tensorInfo ({ 1 }, armnn::DataType::Float32);
172     std::vector<float> tensorData =  { 1.0f };
173     try
174     {
175         armnn::ConstTensor ct(tensorInfo, tensorData);
176         FAIL("InvalidArgumentException should have been thrown");
177     }
178     catch(const InvalidArgumentException& exc)
179     {
180         CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from non-constant TensorInfo.") == 0);
181     }
182 }
183 
184 TEST_CASE("PassTensorToConstTensor_TensorInfoNotConstantError")
185 {
186     try
187     {
188         armnn::ConstTensor t = ConstTensor(Tensor());
189         FAIL("InvalidArgumentException should have been thrown");
190     }
191     catch(const InvalidArgumentException& exc)
192     {
193         CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from "
194                                  "Tensor due to non-constant TensorInfo") == 0);
195     }
196 }
197 
198 TEST_CASE("ModifyTensorInfo")
199 {
200     TensorInfo info;
201     info.SetShape({ 5, 6, 7, 8 });
202     CHECK((info.GetShape() == TensorShape({ 5, 6, 7, 8 })));
203     info.SetDataType(DataType::QAsymmU8);
204     CHECK((info.GetDataType() == DataType::QAsymmU8));
205     info.SetQuantizationScale(10.0f);
206     CHECK(info.GetQuantizationScale() == 10.0f);
207     info.SetQuantizationOffset(5);
208     CHECK(info.GetQuantizationOffset() == 5);
209 }
210 
211 TEST_CASE("TensorShapeOperatorBrackets")
212 {
213     const TensorShape constShape({0,1,2,3});
214     TensorShape shape({0,1,2,3});
215 
216     // Checks version of operator[] which returns an unsigned int.
217     CHECK(shape[2] == 2);
218     shape[2] = 20;
219     CHECK(shape[2] == 20);
220 
221     // Checks the version of operator[] which returns a reference.
222     CHECK(constShape[2] == 2);
223 }
224 
225 TEST_CASE("TensorInfoPerAxisQuantization")
226 {
227     // Old constructor
228     TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
229     CHECK(!tensorInfo0.HasMultipleQuantizationScales());
230     CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
231     CHECK(tensorInfo0.GetQuantizationOffset() == 1);
232     CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
233     CHECK(!tensorInfo0.GetQuantizationDim().has_value());
234 
235     // Set per-axis quantization scales
236     std::vector<float> perAxisScales{ 3.0f, 4.0f };
237     tensorInfo0.SetQuantizationScales(perAxisScales);
238     CHECK(tensorInfo0.HasMultipleQuantizationScales());
239     CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
240 
241     // Set per-tensor quantization scale
242     tensorInfo0.SetQuantizationScale(5.0f);
243     CHECK(!tensorInfo0.HasMultipleQuantizationScales());
244     CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
245 
246     // Set quantization offset
247     tensorInfo0.SetQuantizationDim(Optional<unsigned int>(1));
248     CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
249 
250     // New constructor
251     perAxisScales = { 6.0f, 7.0f };
252     TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
253     CHECK(tensorInfo1.HasMultipleQuantizationScales());
254     CHECK(tensorInfo1.GetQuantizationOffset() == 0);
255     CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
256     CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
257 }
258 
259 TEST_CASE("TensorShape_scalar")
260 {
261     float mutableDatum = 3.1416f;
262 
263     const armnn::TensorShape shape  (armnn::Dimensionality::Scalar );
264     armnn::TensorInfo        info   ( shape, DataType::Float32 );
265     const armnn::Tensor      tensor ( info, &mutableDatum );
266 
267     CHECK(armnn::Dimensionality::Scalar == shape.GetDimensionality());
268     float scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
269     CHECK_MESSAGE(mutableDatum == scalarValue, "Scalar value is " << scalarValue);
270 
271     armnn::TensorShape shape_equal;
272     armnn::TensorShape shape_different;
273     shape_equal = shape;
274     CHECK(shape_equal == shape);
275     CHECK(shape_different != shape);
276     CHECK_MESSAGE(1 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
277     CHECK_MESSAGE(1 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
278     CHECK(true == shape.GetDimensionSpecificity(0));
279     CHECK(shape.AreAllDimensionsSpecified());
280     CHECK(shape.IsAtLeastOneDimensionSpecified());
281 
282     CHECK(1 == shape[0]);
283     CHECK(1 == tensor.GetShape()[0]);
284     CHECK(1 == tensor.GetInfo().GetShape()[0]);
285     CHECK_THROWS_AS( shape[1], InvalidArgumentException );
286 
287     float newMutableDatum  = 42.f;
288     std::memcpy(tensor.GetMemoryArea(), &newMutableDatum, sizeof(float));
289     scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
290     CHECK_MESSAGE(newMutableDatum == scalarValue, "Scalar value is " << scalarValue);
291 }
292 
293 TEST_CASE("TensorShape_DynamicTensorType1_unknownNumberDimensions")
294 {
295     float       mutableDatum  = 3.1416f;
296 
297     armnn::TensorShape shape  (armnn::Dimensionality::NotSpecified );
298     armnn::TensorInfo  info   ( shape, DataType::Float32 );
299     armnn::Tensor      tensor ( info, &mutableDatum );
300 
301     CHECK(armnn::Dimensionality::NotSpecified == shape.GetDimensionality());
302     CHECK_THROWS_AS( shape[0], InvalidArgumentException );
303     CHECK_THROWS_AS( shape.GetNumElements(), InvalidArgumentException );
304     CHECK_THROWS_AS( shape.GetNumDimensions(), InvalidArgumentException );
305 
306     armnn::TensorShape shape_equal;
307     armnn::TensorShape shape_different;
308     shape_equal = shape;
309     CHECK(shape_equal == shape);
310     CHECK(shape_different != shape);
311 }
312 
313 TEST_CASE("TensorShape_DynamicTensorType1_unknownAllDimensionsSizes")
314 {
315     float       mutableDatum  = 3.1416f;
316 
317     armnn::TensorShape shape  ( 3, false );
318     armnn::TensorInfo  info   ( shape, DataType::Float32 );
319     armnn::Tensor      tensor ( info, &mutableDatum );
320 
321     CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
322     CHECK_MESSAGE(0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
323     CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
324     CHECK(false == shape.GetDimensionSpecificity(0));
325     CHECK(false == shape.GetDimensionSpecificity(1));
326     CHECK(false == shape.GetDimensionSpecificity(2));
327     CHECK(!shape.AreAllDimensionsSpecified());
328     CHECK(!shape.IsAtLeastOneDimensionSpecified());
329 
330     armnn::TensorShape shape_equal;
331     armnn::TensorShape shape_different;
332     shape_equal = shape;
333     CHECK(shape_equal == shape);
334     CHECK(shape_different != shape);
335 }
336 
337 TEST_CASE("TensorShape_DynamicTensorType1_unknownSomeDimensionsSizes")
338 {
339     std::vector<float> mutableDatum  { 42.f, 42.f, 42.f,
340                                        0.0f, 0.1f, 0.2f };
341 
342     armnn::TensorShape shape         ( {2, 0, 3}, {true, false, true} );
343     armnn::TensorInfo  info          ( shape, DataType::Float32 );
344     armnn::Tensor      tensor        ( info, &mutableDatum );
345 
346     CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
347     CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
348     CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
349     CHECK(true  == shape.GetDimensionSpecificity(0));
350     CHECK(false == shape.GetDimensionSpecificity(1));
351     CHECK(true  == shape.GetDimensionSpecificity(2));
352     CHECK(!shape.AreAllDimensionsSpecified());
353     CHECK(shape.IsAtLeastOneDimensionSpecified());
354 
355     CHECK_THROWS_AS(shape[1], InvalidArgumentException);
356     CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
357     CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
358 
359     CHECK(2 == shape[0]);
360     CHECK(2 == tensor.GetShape()[0]);
361     CHECK(2 == tensor.GetInfo().GetShape()[0]);
362     CHECK_THROWS_AS( shape[1], InvalidArgumentException );
363 
364     CHECK(3 == shape[2]);
365     CHECK(3 == tensor.GetShape()[2]);
366     CHECK(3 == tensor.GetInfo().GetShape()[2]);
367 
368     armnn::TensorShape shape_equal;
369     armnn::TensorShape shape_different;
370     shape_equal = shape;
371     CHECK(shape_equal == shape);
372     CHECK(shape_different != shape);
373 }
374 
375 TEST_CASE("TensorShape_DynamicTensorType1_transitionFromUnknownToKnownDimensionsSizes")
376 {
377     std::vector<float> mutableDatum  { 42.f, 42.f, 42.f,
378                                        0.0f, 0.1f, 0.2f };
379 
380     armnn::TensorShape shape         (armnn::Dimensionality::NotSpecified );
381     armnn::TensorInfo  info          ( shape, DataType::Float32 );
382     armnn::Tensor      tensor        ( info, &mutableDatum );
383 
384     // Specify the number of dimensions
385     shape.SetNumDimensions(3);
386     CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
387     CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
388     CHECK(false == shape.GetDimensionSpecificity(0));
389     CHECK(false == shape.GetDimensionSpecificity(1));
390     CHECK(false == shape.GetDimensionSpecificity(2));
391     CHECK(!shape.AreAllDimensionsSpecified());
392     CHECK(!shape.IsAtLeastOneDimensionSpecified());
393 
394     // Specify dimension 0 and 2.
395     shape.SetDimensionSize(0, 2);
396     shape.SetDimensionSize(2, 3);
397     CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
398     CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
399     CHECK(true  == shape.GetDimensionSpecificity(0));
400     CHECK(false == shape.GetDimensionSpecificity(1));
401     CHECK(true  == shape.GetDimensionSpecificity(2));
402     CHECK(!shape.AreAllDimensionsSpecified());
403     CHECK(shape.IsAtLeastOneDimensionSpecified());
404 
405     info.SetShape(shape);
406     armnn::Tensor tensor2( info, &mutableDatum );
407     CHECK(2 == shape[0]);
408     CHECK(2 == tensor2.GetShape()[0]);
409     CHECK(2 == tensor2.GetInfo().GetShape()[0]);
410 
411     CHECK_THROWS_AS(shape[1], InvalidArgumentException);
412     CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
413     CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
414 
415     CHECK(3 == shape[2]);
416     CHECK(3 == tensor2.GetShape()[2]);
417     CHECK(3 == tensor2.GetInfo().GetShape()[2]);
418 
419     armnn::TensorShape shape_equal;
420     armnn::TensorShape shape_different;
421     shape_equal = shape;
422     CHECK(shape_equal == shape);
423     CHECK(shape_different != shape);
424 
425     // Specify dimension 1.
426     shape.SetDimensionSize(1, 5);
427     CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
428     CHECK_MESSAGE(30 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
429     CHECK(true  == shape.GetDimensionSpecificity(0));
430     CHECK(true  == shape.GetDimensionSpecificity(1));
431     CHECK(true  == shape.GetDimensionSpecificity(2));
432     CHECK(shape.AreAllDimensionsSpecified());
433     CHECK(shape.IsAtLeastOneDimensionSpecified());
434 }
435 
436 TEST_CASE("Tensor_emptyConstructors")
437 {
438     auto shape = armnn::TensorShape();
439     CHECK_MESSAGE( 0 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
440     CHECK_MESSAGE( 0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
441     CHECK( armnn::Dimensionality::Specified == shape.GetDimensionality());
442     CHECK( shape.AreAllDimensionsSpecified());
443     CHECK_THROWS_AS( shape[0], InvalidArgumentException );
444 
445     auto tensor = armnn::Tensor();
446     CHECK_MESSAGE( 0 == tensor.GetNumDimensions(), "Number of dimensions is " << tensor.GetNumDimensions());
447     CHECK_MESSAGE( 0 == tensor.GetNumElements(), "Number of elements is " << tensor.GetNumElements());
448     CHECK_MESSAGE( 0 == tensor.GetShape().GetNumDimensions(), "Number of dimensions is " <<
449                         tensor.GetShape().GetNumDimensions());
450     CHECK_MESSAGE( 0 == tensor.GetShape().GetNumElements(), "Number of dimensions is " <<
451                         tensor.GetShape().GetNumElements());
452     CHECK( armnn::Dimensionality::Specified == tensor.GetShape().GetDimensionality());
453     CHECK( tensor.GetShape().AreAllDimensionsSpecified());
454     CHECK_THROWS_AS( tensor.GetShape()[0], InvalidArgumentException );
455 }
456 }
457