1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include <armnn/Tensor.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <doctest/doctest.h>

#include <cstring>
#include <string>
#include <vector>

using namespace armnn;
12
13 TEST_SUITE("Tensor")
14 {
15 struct TensorInfoFixture
16 {
TensorInfoFixtureTensorInfoFixture17 TensorInfoFixture()
18 {
19 unsigned int sizes[] = {6,7,8,9};
20 m_TensorInfo = TensorInfo(4, sizes, DataType::Float32);
21 }
~TensorInfoFixtureTensorInfoFixture22 ~TensorInfoFixture() {};
23
24 TensorInfo m_TensorInfo;
25 };
26
27 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructShapeUsingListInitialization")
28 {
29 TensorShape listInitializedShape{ 6, 7, 8, 9 };
30 CHECK(listInitializedShape == m_TensorInfo.GetShape());
31 }
32
33 TEST_CASE_FIXTURE(TensorInfoFixture, "ConstructTensorInfo")
34 {
35 CHECK(m_TensorInfo.GetNumDimensions() == 4);
36 CHECK(m_TensorInfo.GetShape()[0] == 6); // <= Outer most
37 CHECK(m_TensorInfo.GetShape()[1] == 7);
38 CHECK(m_TensorInfo.GetShape()[2] == 8);
39 CHECK(m_TensorInfo.GetShape()[3] == 9); // <= Inner most
40 }
41
42 TEST_CASE_FIXTURE(TensorInfoFixture, "CopyConstructTensorInfo")
43 {
44 TensorInfo copyConstructed(m_TensorInfo);
45 CHECK(copyConstructed.GetNumDimensions() == 4);
46 CHECK(copyConstructed.GetShape()[0] == 6);
47 CHECK(copyConstructed.GetShape()[1] == 7);
48 CHECK(copyConstructed.GetShape()[2] == 8);
49 CHECK(copyConstructed.GetShape()[3] == 9);
50 }
51
52 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoEquality")
53 {
54 TensorInfo copyConstructed(m_TensorInfo);
55 CHECK(copyConstructed == m_TensorInfo);
56 }
57
58 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoInequality")
59 {
60 TensorInfo other;
61 unsigned int sizes[] = {2,3,4,5};
62 other = TensorInfo(4, sizes, DataType::Float32);
63
64 CHECK(other != m_TensorInfo);
65 }
66
67 TEST_CASE_FIXTURE(TensorInfoFixture, "TensorInfoAssignmentOperator")
68 {
69 TensorInfo copy;
70 copy = m_TensorInfo;
71 CHECK(copy == m_TensorInfo);
72 }
73
74 TEST_CASE("CopyNoQuantizationTensorInfo")
75 {
76 TensorInfo infoA;
77 infoA.SetShape({ 5, 6, 7, 8 });
78 infoA.SetDataType(DataType::QAsymmU8);
79
80 TensorInfo infoB;
81 infoB.SetShape({ 5, 6, 7, 8 });
82 infoB.SetDataType(DataType::QAsymmU8);
83 infoB.SetQuantizationScale(10.0f);
84 infoB.SetQuantizationOffset(5);
85 infoB.SetQuantizationDim(Optional<unsigned int>(1));
86
87 CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
88 CHECK((infoA.GetDataType() == DataType::QAsymmU8));
89 CHECK(infoA.GetQuantizationScale() == 1);
90 CHECK(infoA.GetQuantizationOffset() == 0);
91 CHECK(!infoA.GetQuantizationDim().has_value());
92
93 CHECK(infoA != infoB);
94 infoA = infoB;
95 CHECK(infoA == infoB);
96
97 CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
98 CHECK((infoA.GetDataType() == DataType::QAsymmU8));
99 CHECK(infoA.GetQuantizationScale() == 10.0f);
100 CHECK(infoA.GetQuantizationOffset() == 5);
101 CHECK(infoA.GetQuantizationDim().value() == 1);
102 }
103
104 TEST_CASE("CopyDifferentQuantizationTensorInfo")
105 {
106 TensorInfo infoA;
107 infoA.SetShape({ 5, 6, 7, 8 });
108 infoA.SetDataType(DataType::QAsymmU8);
109 infoA.SetQuantizationScale(10.0f);
110 infoA.SetQuantizationOffset(5);
111 infoA.SetQuantizationDim(Optional<unsigned int>(1));
112
113 TensorInfo infoB;
114 infoB.SetShape({ 5, 6, 7, 8 });
115 infoB.SetDataType(DataType::QAsymmU8);
116 infoB.SetQuantizationScale(11.0f);
117 infoB.SetQuantizationOffset(6);
118 infoB.SetQuantizationDim(Optional<unsigned int>(2));
119
120 CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
121 CHECK((infoA.GetDataType() == DataType::QAsymmU8));
122 CHECK(infoA.GetQuantizationScale() == 10.0f);
123 CHECK(infoA.GetQuantizationOffset() == 5);
124 CHECK(infoA.GetQuantizationDim().value() == 1);
125
126 CHECK(infoA != infoB);
127 infoA = infoB;
128 CHECK(infoA == infoB);
129
130 CHECK((infoA.GetShape() == TensorShape({ 5, 6, 7, 8 })));
131 CHECK((infoA.GetDataType() == DataType::QAsymmU8));
132 CHECK(infoA.GetQuantizationScale() == 11.0f);
133 CHECK(infoA.GetQuantizationOffset() == 6);
134 CHECK(infoA.GetQuantizationDim().value() == 2);
135 }
136
CheckTensor(const ConstTensor & t)137 void CheckTensor(const ConstTensor& t)
138 {
139 t.GetInfo();
140 }
141
142 TEST_CASE("TensorVsConstTensor")
143 {
144 int mutableDatum = 2;
145 const int immutableDatum = 3;
146
147 armnn::Tensor uninitializedTensor;
148 uninitializedTensor.GetInfo().SetConstant(true);
149 armnn::ConstTensor uninitializedTensor2;
150
151 uninitializedTensor2 = uninitializedTensor;
152
153 armnn::TensorInfo emptyTensorInfo;
154 emptyTensorInfo.SetConstant(true);
155 armnn::Tensor t(emptyTensorInfo, &mutableDatum);
156 armnn::ConstTensor ct(emptyTensorInfo, &immutableDatum);
157
158 // Checks that both Tensor and ConstTensor can be passed as a ConstTensor.
159 CheckTensor(t);
160 CheckTensor(ct);
161 }
162
163 TEST_CASE("ConstTensor_EmptyConstructorTensorInfoSet")
164 {
165 armnn::ConstTensor t;
166 CHECK(t.GetInfo().IsConstant() == true);
167 }
168
169 TEST_CASE("ConstTensor_TensorInfoNotConstantError")
170 {
171 armnn::TensorInfo tensorInfo ({ 1 }, armnn::DataType::Float32);
172 std::vector<float> tensorData = { 1.0f };
173 try
174 {
175 armnn::ConstTensor ct(tensorInfo, tensorData);
176 FAIL("InvalidArgumentException should have been thrown");
177 }
178 catch(const InvalidArgumentException& exc)
179 {
180 CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from non-constant TensorInfo.") == 0);
181 }
182 }
183
184 TEST_CASE("PassTensorToConstTensor_TensorInfoNotConstantError")
185 {
186 try
187 {
188 armnn::ConstTensor t = ConstTensor(Tensor());
189 FAIL("InvalidArgumentException should have been thrown");
190 }
191 catch(const InvalidArgumentException& exc)
192 {
193 CHECK(strcmp(exc.what(), "Invalid attempt to construct ConstTensor from "
194 "Tensor due to non-constant TensorInfo") == 0);
195 }
196 }
197
198 TEST_CASE("ModifyTensorInfo")
199 {
200 TensorInfo info;
201 info.SetShape({ 5, 6, 7, 8 });
202 CHECK((info.GetShape() == TensorShape({ 5, 6, 7, 8 })));
203 info.SetDataType(DataType::QAsymmU8);
204 CHECK((info.GetDataType() == DataType::QAsymmU8));
205 info.SetQuantizationScale(10.0f);
206 CHECK(info.GetQuantizationScale() == 10.0f);
207 info.SetQuantizationOffset(5);
208 CHECK(info.GetQuantizationOffset() == 5);
209 }
210
211 TEST_CASE("TensorShapeOperatorBrackets")
212 {
213 const TensorShape constShape({0,1,2,3});
214 TensorShape shape({0,1,2,3});
215
216 // Checks version of operator[] which returns an unsigned int.
217 CHECK(shape[2] == 2);
218 shape[2] = 20;
219 CHECK(shape[2] == 20);
220
221 // Checks the version of operator[] which returns a reference.
222 CHECK(constShape[2] == 2);
223 }
224
225 TEST_CASE("TensorInfoPerAxisQuantization")
226 {
227 // Old constructor
228 TensorInfo tensorInfo0({ 1, 1 }, DataType::Float32, 2.0f, 1);
229 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
230 CHECK(tensorInfo0.GetQuantizationScale() == 2.0f);
231 CHECK(tensorInfo0.GetQuantizationOffset() == 1);
232 CHECK(tensorInfo0.GetQuantizationScales()[0] == 2.0f);
233 CHECK(!tensorInfo0.GetQuantizationDim().has_value());
234
235 // Set per-axis quantization scales
236 std::vector<float> perAxisScales{ 3.0f, 4.0f };
237 tensorInfo0.SetQuantizationScales(perAxisScales);
238 CHECK(tensorInfo0.HasMultipleQuantizationScales());
239 CHECK(tensorInfo0.GetQuantizationScales() == perAxisScales);
240
241 // Set per-tensor quantization scale
242 tensorInfo0.SetQuantizationScale(5.0f);
243 CHECK(!tensorInfo0.HasMultipleQuantizationScales());
244 CHECK(tensorInfo0.GetQuantizationScales()[0] == 5.0f);
245
246 // Set quantization offset
247 tensorInfo0.SetQuantizationDim(Optional<unsigned int>(1));
248 CHECK(tensorInfo0.GetQuantizationDim().value() == 1);
249
250 // New constructor
251 perAxisScales = { 6.0f, 7.0f };
252 TensorInfo tensorInfo1({ 1, 1 }, DataType::Float32, perAxisScales, 1);
253 CHECK(tensorInfo1.HasMultipleQuantizationScales());
254 CHECK(tensorInfo1.GetQuantizationOffset() == 0);
255 CHECK(tensorInfo1.GetQuantizationScales() == perAxisScales);
256 CHECK(tensorInfo1.GetQuantizationDim().value() == 1);
257 }
258
259 TEST_CASE("TensorShape_scalar")
260 {
261 float mutableDatum = 3.1416f;
262
263 const armnn::TensorShape shape (armnn::Dimensionality::Scalar );
264 armnn::TensorInfo info ( shape, DataType::Float32 );
265 const armnn::Tensor tensor ( info, &mutableDatum );
266
267 CHECK(armnn::Dimensionality::Scalar == shape.GetDimensionality());
268 float scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
269 CHECK_MESSAGE(mutableDatum == scalarValue, "Scalar value is " << scalarValue);
270
271 armnn::TensorShape shape_equal;
272 armnn::TensorShape shape_different;
273 shape_equal = shape;
274 CHECK(shape_equal == shape);
275 CHECK(shape_different != shape);
276 CHECK_MESSAGE(1 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
277 CHECK_MESSAGE(1 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
278 CHECK(true == shape.GetDimensionSpecificity(0));
279 CHECK(shape.AreAllDimensionsSpecified());
280 CHECK(shape.IsAtLeastOneDimensionSpecified());
281
282 CHECK(1 == shape[0]);
283 CHECK(1 == tensor.GetShape()[0]);
284 CHECK(1 == tensor.GetInfo().GetShape()[0]);
285 CHECK_THROWS_AS( shape[1], InvalidArgumentException );
286
287 float newMutableDatum = 42.f;
288 std::memcpy(tensor.GetMemoryArea(), &newMutableDatum, sizeof(float));
289 scalarValue = *reinterpret_cast<float*>(tensor.GetMemoryArea());
290 CHECK_MESSAGE(newMutableDatum == scalarValue, "Scalar value is " << scalarValue);
291 }
292
293 TEST_CASE("TensorShape_DynamicTensorType1_unknownNumberDimensions")
294 {
295 float mutableDatum = 3.1416f;
296
297 armnn::TensorShape shape (armnn::Dimensionality::NotSpecified );
298 armnn::TensorInfo info ( shape, DataType::Float32 );
299 armnn::Tensor tensor ( info, &mutableDatum );
300
301 CHECK(armnn::Dimensionality::NotSpecified == shape.GetDimensionality());
302 CHECK_THROWS_AS( shape[0], InvalidArgumentException );
303 CHECK_THROWS_AS( shape.GetNumElements(), InvalidArgumentException );
304 CHECK_THROWS_AS( shape.GetNumDimensions(), InvalidArgumentException );
305
306 armnn::TensorShape shape_equal;
307 armnn::TensorShape shape_different;
308 shape_equal = shape;
309 CHECK(shape_equal == shape);
310 CHECK(shape_different != shape);
311 }
312
313 TEST_CASE("TensorShape_DynamicTensorType1_unknownAllDimensionsSizes")
314 {
315 float mutableDatum = 3.1416f;
316
317 armnn::TensorShape shape ( 3, false );
318 armnn::TensorInfo info ( shape, DataType::Float32 );
319 armnn::Tensor tensor ( info, &mutableDatum );
320
321 CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
322 CHECK_MESSAGE(0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
323 CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
324 CHECK(false == shape.GetDimensionSpecificity(0));
325 CHECK(false == shape.GetDimensionSpecificity(1));
326 CHECK(false == shape.GetDimensionSpecificity(2));
327 CHECK(!shape.AreAllDimensionsSpecified());
328 CHECK(!shape.IsAtLeastOneDimensionSpecified());
329
330 armnn::TensorShape shape_equal;
331 armnn::TensorShape shape_different;
332 shape_equal = shape;
333 CHECK(shape_equal == shape);
334 CHECK(shape_different != shape);
335 }
336
337 TEST_CASE("TensorShape_DynamicTensorType1_unknownSomeDimensionsSizes")
338 {
339 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
340 0.0f, 0.1f, 0.2f };
341
342 armnn::TensorShape shape ( {2, 0, 3}, {true, false, true} );
343 armnn::TensorInfo info ( shape, DataType::Float32 );
344 armnn::Tensor tensor ( info, &mutableDatum );
345
346 CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
347 CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
348 CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
349 CHECK(true == shape.GetDimensionSpecificity(0));
350 CHECK(false == shape.GetDimensionSpecificity(1));
351 CHECK(true == shape.GetDimensionSpecificity(2));
352 CHECK(!shape.AreAllDimensionsSpecified());
353 CHECK(shape.IsAtLeastOneDimensionSpecified());
354
355 CHECK_THROWS_AS(shape[1], InvalidArgumentException);
356 CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
357 CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
358
359 CHECK(2 == shape[0]);
360 CHECK(2 == tensor.GetShape()[0]);
361 CHECK(2 == tensor.GetInfo().GetShape()[0]);
362 CHECK_THROWS_AS( shape[1], InvalidArgumentException );
363
364 CHECK(3 == shape[2]);
365 CHECK(3 == tensor.GetShape()[2]);
366 CHECK(3 == tensor.GetInfo().GetShape()[2]);
367
368 armnn::TensorShape shape_equal;
369 armnn::TensorShape shape_different;
370 shape_equal = shape;
371 CHECK(shape_equal == shape);
372 CHECK(shape_different != shape);
373 }
374
375 TEST_CASE("TensorShape_DynamicTensorType1_transitionFromUnknownToKnownDimensionsSizes")
376 {
377 std::vector<float> mutableDatum { 42.f, 42.f, 42.f,
378 0.0f, 0.1f, 0.2f };
379
380 armnn::TensorShape shape (armnn::Dimensionality::NotSpecified );
381 armnn::TensorInfo info ( shape, DataType::Float32 );
382 armnn::Tensor tensor ( info, &mutableDatum );
383
384 // Specify the number of dimensions
385 shape.SetNumDimensions(3);
386 CHECK(armnn::Dimensionality::Specified == shape.GetDimensionality());
387 CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
388 CHECK(false == shape.GetDimensionSpecificity(0));
389 CHECK(false == shape.GetDimensionSpecificity(1));
390 CHECK(false == shape.GetDimensionSpecificity(2));
391 CHECK(!shape.AreAllDimensionsSpecified());
392 CHECK(!shape.IsAtLeastOneDimensionSpecified());
393
394 // Specify dimension 0 and 2.
395 shape.SetDimensionSize(0, 2);
396 shape.SetDimensionSize(2, 3);
397 CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
398 CHECK_MESSAGE(6 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
399 CHECK(true == shape.GetDimensionSpecificity(0));
400 CHECK(false == shape.GetDimensionSpecificity(1));
401 CHECK(true == shape.GetDimensionSpecificity(2));
402 CHECK(!shape.AreAllDimensionsSpecified());
403 CHECK(shape.IsAtLeastOneDimensionSpecified());
404
405 info.SetShape(shape);
406 armnn::Tensor tensor2( info, &mutableDatum );
407 CHECK(2 == shape[0]);
408 CHECK(2 == tensor2.GetShape()[0]);
409 CHECK(2 == tensor2.GetInfo().GetShape()[0]);
410
411 CHECK_THROWS_AS(shape[1], InvalidArgumentException);
412 CHECK_THROWS_AS(tensor.GetShape()[1], InvalidArgumentException);
413 CHECK_THROWS_AS(tensor.GetInfo().GetShape()[1], InvalidArgumentException);
414
415 CHECK(3 == shape[2]);
416 CHECK(3 == tensor2.GetShape()[2]);
417 CHECK(3 == tensor2.GetInfo().GetShape()[2]);
418
419 armnn::TensorShape shape_equal;
420 armnn::TensorShape shape_different;
421 shape_equal = shape;
422 CHECK(shape_equal == shape);
423 CHECK(shape_different != shape);
424
425 // Specify dimension 1.
426 shape.SetDimensionSize(1, 5);
427 CHECK_MESSAGE(3 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
428 CHECK_MESSAGE(30 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
429 CHECK(true == shape.GetDimensionSpecificity(0));
430 CHECK(true == shape.GetDimensionSpecificity(1));
431 CHECK(true == shape.GetDimensionSpecificity(2));
432 CHECK(shape.AreAllDimensionsSpecified());
433 CHECK(shape.IsAtLeastOneDimensionSpecified());
434 }
435
436 TEST_CASE("Tensor_emptyConstructors")
437 {
438 auto shape = armnn::TensorShape();
439 CHECK_MESSAGE( 0 == shape.GetNumDimensions(), "Number of dimensions is " << shape.GetNumDimensions());
440 CHECK_MESSAGE( 0 == shape.GetNumElements(), "Number of elements is " << shape.GetNumElements());
441 CHECK( armnn::Dimensionality::Specified == shape.GetDimensionality());
442 CHECK( shape.AreAllDimensionsSpecified());
443 CHECK_THROWS_AS( shape[0], InvalidArgumentException );
444
445 auto tensor = armnn::Tensor();
446 CHECK_MESSAGE( 0 == tensor.GetNumDimensions(), "Number of dimensions is " << tensor.GetNumDimensions());
447 CHECK_MESSAGE( 0 == tensor.GetNumElements(), "Number of elements is " << tensor.GetNumElements());
448 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumDimensions(), "Number of dimensions is " <<
449 tensor.GetShape().GetNumDimensions());
450 CHECK_MESSAGE( 0 == tensor.GetShape().GetNumElements(), "Number of dimensions is " <<
451 tensor.GetShape().GetNumElements());
452 CHECK( armnn::Dimensionality::Specified == tensor.GetShape().GetDimensionality());
453 CHECK( tensor.GetShape().AreAllDimensionsSpecified());
454 CHECK_THROWS_AS( tensor.GetShape()[0], InvalidArgumentException );
455 }
456 }
457