1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 13*89c4ff92SAndroid Build Coastguard Worker using namespace armnnUtils; 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorUtilsSuite") 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis0Test") 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker // Expand dimension 0 22*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, 0); 23*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 24*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 25*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 2); 26*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 27*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 28*89c4ff92SAndroid Build Coastguard Worker } 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis1Test") 31*89c4ff92SAndroid Build Coastguard Worker { 32*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker // Expand dimension 1 35*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, 1); 36*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 37*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 38*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 39*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 40*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 41*89c4ff92SAndroid Build Coastguard Worker } 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis2Test") 44*89c4ff92SAndroid Build Coastguard Worker { 45*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker // Expand dimension 2 48*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, 2); 49*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 50*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 51*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 3); 52*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 1); 53*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsAxis3Test") 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 59*89c4ff92SAndroid Build Coastguard Worker 60*89c4ff92SAndroid Build Coastguard Worker // Expand dimension 3 61*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, 3); 62*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 63*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 64*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 3); 65*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 4); 66*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 1); 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis1Test") 70*89c4ff92SAndroid Build Coastguard Worker { 71*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 72*89c4ff92SAndroid Build Coastguard Worker 73*89c4ff92SAndroid Build Coastguard Worker // Expand dimension -1 74*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, -1); 75*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 76*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 77*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 3); 78*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 4); 79*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 1); 80*89c4ff92SAndroid Build Coastguard Worker } 81*89c4ff92SAndroid Build Coastguard Worker 82*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis2Test") 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker // Expand dimension -2 87*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, -2); 88*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 89*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 90*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 3); 91*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 1); 92*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 93*89c4ff92SAndroid Build Coastguard Worker } 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis3Test") 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 98*89c4ff92SAndroid Build Coastguard Worker 99*89c4ff92SAndroid Build Coastguard Worker // Expand dimension -3 100*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, -3); 101*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 102*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 103*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 104*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 105*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsNegativeAxis4Test") 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker // Expand dimension -4 113*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDims(inputShape, -4); 114*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 115*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 116*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 2); 117*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 118*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 119*89c4ff92SAndroid Build Coastguard Worker } 120*89c4ff92SAndroid Build Coastguard Worker 121*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidAxisTest") 122*89c4ff92SAndroid Build Coastguard Worker { 123*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 124*89c4ff92SAndroid Build Coastguard Worker 125*89c4ff92SAndroid Build Coastguard Worker // Invalid expand dimension 4 126*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException); 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidNegativeAxisTest") 130*89c4ff92SAndroid Build Coastguard Worker { 131*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 132*89c4ff92SAndroid Build Coastguard Worker 133*89c4ff92SAndroid Build Coastguard Worker // Invalid expand dimension -5 134*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException); 135*89c4ff92SAndroid Build Coastguard Worker } 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy1Rank") 138*89c4ff92SAndroid Build Coastguard Worker { 139*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker // Expand by 1 dimension 142*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); 143*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 144*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 145*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 2); 146*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 147*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 148*89c4ff92SAndroid Build Coastguard Worker } 149*89c4ff92SAndroid Build Coastguard Worker 150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy2Ranks") 151*89c4ff92SAndroid Build Coastguard Worker { 152*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 3, 4 }); 153*89c4ff92SAndroid Build Coastguard Worker 154*89c4ff92SAndroid Build Coastguard Worker // Expand 2 dimensions 155*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); 156*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 157*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 158*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 159*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 3); 160*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker 163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsBy3Ranks") 164*89c4ff92SAndroid Build Coastguard Worker { 165*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 4 }); 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker // Expand 3 dimensions 168*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4); 169*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 4); 170*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 171*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 172*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 1); 173*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[3] == 4); 174*89c4ff92SAndroid Build Coastguard Worker } 175*89c4ff92SAndroid Build Coastguard Worker 176*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsInvalidRankAmount") 177*89c4ff92SAndroid Build Coastguard Worker { 178*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker // Don't expand because target rank is smaller than current rank 181*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2); 182*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 3); 183*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 184*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 3); 185*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 4); 186*89c4ff92SAndroid Build Coastguard Worker } 187*89c4ff92SAndroid Build Coastguard Worker 188*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ExpandDimsToRankInvalidTensorShape") 189*89c4ff92SAndroid Build Coastguard Worker { 190*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 2, 3, 4 }); 191*89c4ff92SAndroid Build Coastguard Worker 192*89c4ff92SAndroid Build Coastguard Worker // Throw exception because rank 6 tensors are unsupported by armnn 193*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException); 194*89c4ff92SAndroid Build Coastguard Worker } 195*89c4ff92SAndroid Build Coastguard Worker 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeAll1s") 198*89c4ff92SAndroid Build Coastguard Worker { 199*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 1, 1, 1 }); 200*89c4ff92SAndroid Build Coastguard Worker 201*89c4ff92SAndroid Build Coastguard Worker // Reduce dimension 2 202*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ReduceDims(inputShape, 2); 203*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 2); 204*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 205*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 206*89c4ff92SAndroid Build Coastguard Worker } 207*89c4ff92SAndroid Build Coastguard Worker 208*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeNotEnough1s") 209*89c4ff92SAndroid Build Coastguard Worker { 210*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 1, 2, 1 }); 211*89c4ff92SAndroid Build Coastguard Worker 212*89c4ff92SAndroid Build Coastguard Worker // Reduce dimension 1 213*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ReduceDims(inputShape, 1); 214*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 2); 215*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 2); 216*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 217*89c4ff92SAndroid Build Coastguard Worker } 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsInfoAll1s") 220*89c4ff92SAndroid Build Coastguard Worker { 221*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32); 222*89c4ff92SAndroid Build Coastguard Worker 223*89c4ff92SAndroid Build Coastguard Worker // Reduce dimension 2 224*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2); 225*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape().GetNumDimensions() == 2); 226*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape()[0] == 1); 227*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape()[1] == 1); 228*89c4ff92SAndroid Build Coastguard Worker } 229*89c4ff92SAndroid Build Coastguard Worker 230*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsInfoNotEnough1s") 231*89c4ff92SAndroid Build Coastguard Worker { 232*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32); 233*89c4ff92SAndroid Build Coastguard Worker 234*89c4ff92SAndroid Build Coastguard Worker // Reduce dimension 1 235*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1); 236*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetNumDimensions() == 2); 237*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape()[0] == 2); 238*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape()[1] == 1); 239*89c4ff92SAndroid Build Coastguard Worker } 240*89c4ff92SAndroid Build Coastguard Worker 241*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize") 242*89c4ff92SAndroid Build Coastguard Worker { 243*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape inputShape({ 1, 1, 1 }); 244*89c4ff92SAndroid Build Coastguard Worker 245*89c4ff92SAndroid Build Coastguard Worker // Do not reduce because dimension does not exist 246*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape outputShape = ReduceDims(inputShape, 4); 247*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape.GetNumDimensions() == 3); 248*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[0] == 1); 249*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[1] == 1); 250*89c4ff92SAndroid Build Coastguard Worker CHECK(outputShape[2] == 1); 251*89c4ff92SAndroid Build Coastguard Worker } 252*89c4ff92SAndroid Build Coastguard Worker 253*89c4ff92SAndroid Build Coastguard Worker 254*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayInvalidDataType") 255*89c4ff92SAndroid Build Coastguard Worker { 256*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 2, 3, 4 }, armnn::DataType::BFloat16); 257*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data {1,2,3,4,5,6,7,8,9,10}; 258*89c4ff92SAndroid Build Coastguard Worker 259*89c4ff92SAndroid Build Coastguard Worker // Invalid argument 260*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(ToFloatArray(data, info), armnn::InvalidArgumentException); 261*89c4ff92SAndroid Build Coastguard Worker } 262*89c4ff92SAndroid Build Coastguard Worker 263*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQSymmS8PerAxis") 264*89c4ff92SAndroid Build Coastguard Worker { 265*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f }; 266*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = 1; 267*89c4ff92SAndroid Build Coastguard Worker 268*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, quantizationScales, quantizationDim); 269*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 }; 270*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f }; 271*89c4ff92SAndroid Build Coastguard Worker 272*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 273*89c4ff92SAndroid Build Coastguard Worker 274*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 275*89c4ff92SAndroid Build Coastguard Worker { 276*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 277*89c4ff92SAndroid Build Coastguard Worker } 278*89c4ff92SAndroid Build Coastguard Worker } 279*89c4ff92SAndroid Build Coastguard Worker 280*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQSymmS8") 281*89c4ff92SAndroid Build Coastguard Worker { 282*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, 0.1f); 283*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 }; 284*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f }; 285*89c4ff92SAndroid Build Coastguard Worker 286*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 287*89c4ff92SAndroid Build Coastguard Worker 288*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 289*89c4ff92SAndroid Build Coastguard Worker { 290*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 291*89c4ff92SAndroid Build Coastguard Worker } 292*89c4ff92SAndroid Build Coastguard Worker } 293*89c4ff92SAndroid Build Coastguard Worker 294*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmS8PerAxis") 295*89c4ff92SAndroid Build Coastguard Worker { 296*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f }; 297*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = 1; 298*89c4ff92SAndroid Build Coastguard Worker 299*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, quantizationScales, quantizationDim); 300*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 }; 301*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f }; 302*89c4ff92SAndroid Build Coastguard Worker 303*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 304*89c4ff92SAndroid Build Coastguard Worker 305*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 306*89c4ff92SAndroid Build Coastguard Worker { 307*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 308*89c4ff92SAndroid Build Coastguard Worker } 309*89c4ff92SAndroid Build Coastguard Worker } 310*89c4ff92SAndroid Build Coastguard Worker 311*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmS8") 312*89c4ff92SAndroid Build Coastguard Worker { 313*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, 0.1f); 314*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 }; 315*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f }; 316*89c4ff92SAndroid Build Coastguard Worker 317*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 318*89c4ff92SAndroid Build Coastguard Worker 319*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 320*89c4ff92SAndroid Build Coastguard Worker { 321*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 322*89c4ff92SAndroid Build Coastguard Worker } 323*89c4ff92SAndroid Build Coastguard Worker } 324*89c4ff92SAndroid Build Coastguard Worker 325*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQASymmU8PerAxis") 326*89c4ff92SAndroid Build Coastguard Worker { 327*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f }; 328*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = 1; 329*89c4ff92SAndroid Build Coastguard Worker 330*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, quantizationScales, quantizationDim); 331*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 }; 332*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f }; 333*89c4ff92SAndroid Build Coastguard Worker 334*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 335*89c4ff92SAndroid Build Coastguard Worker 336*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 337*89c4ff92SAndroid Build Coastguard Worker { 338*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 339*89c4ff92SAndroid Build Coastguard Worker } 340*89c4ff92SAndroid Build Coastguard Worker } 341*89c4ff92SAndroid Build Coastguard Worker 342*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArrayQAsymmU8") 343*89c4ff92SAndroid Build Coastguard Worker { 344*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, 0.1f); 345*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 }; 346*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f }; 347*89c4ff92SAndroid Build Coastguard Worker 348*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 349*89c4ff92SAndroid Build Coastguard Worker 350*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 351*89c4ff92SAndroid Build Coastguard Worker { 352*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 353*89c4ff92SAndroid Build Coastguard Worker } 354*89c4ff92SAndroid Build Coastguard Worker } 355*89c4ff92SAndroid Build Coastguard Worker 356*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned32PerAxis") 357*89c4ff92SAndroid Build Coastguard Worker { 358*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f }; 359*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = 1; 360*89c4ff92SAndroid Build Coastguard Worker 361*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, quantizationScales, quantizationDim); 362*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0, 363*89c4ff92SAndroid Build Coastguard Worker 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 }; 364*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f }; 365*89c4ff92SAndroid Build Coastguard Worker 366*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 367*89c4ff92SAndroid Build Coastguard Worker 368*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 369*89c4ff92SAndroid Build Coastguard Worker { 370*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 371*89c4ff92SAndroid Build Coastguard Worker } 372*89c4ff92SAndroid Build Coastguard Worker } 373*89c4ff92SAndroid Build Coastguard Worker 374*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned32") 375*89c4ff92SAndroid Build Coastguard Worker { 376*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, 0.1f); 377*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0, 378*89c4ff92SAndroid Build Coastguard Worker 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 }; 379*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f }; 380*89c4ff92SAndroid Build Coastguard Worker 381*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 382*89c4ff92SAndroid Build Coastguard Worker 383*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 384*89c4ff92SAndroid Build Coastguard Worker { 385*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 386*89c4ff92SAndroid Build Coastguard Worker } 387*89c4ff92SAndroid Build Coastguard Worker } 388*89c4ff92SAndroid Build Coastguard Worker 389*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned64PerAxis") 390*89c4ff92SAndroid Build Coastguard Worker { 391*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f }; 392*89c4ff92SAndroid Build Coastguard Worker unsigned int quantizationDim = 1; 393*89c4ff92SAndroid Build Coastguard Worker 394*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, quantizationScales, quantizationDim); 395*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0, 396*89c4ff92SAndroid Build Coastguard Worker 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0, 397*89c4ff92SAndroid Build Coastguard Worker 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0, 398*89c4ff92SAndroid Build Coastguard Worker 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 }; 399*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f }; 400*89c4ff92SAndroid Build Coastguard Worker 401*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 402*89c4ff92SAndroid Build Coastguard Worker 403*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 404*89c4ff92SAndroid Build Coastguard Worker { 405*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 406*89c4ff92SAndroid Build Coastguard Worker } 407*89c4ff92SAndroid Build Coastguard Worker } 408*89c4ff92SAndroid Build Coastguard Worker 409*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ToFloatArraySigned64") 410*89c4ff92SAndroid Build Coastguard Worker { 411*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, 0.1f); 412*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0, 413*89c4ff92SAndroid Build Coastguard Worker 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0, 414*89c4ff92SAndroid Build Coastguard Worker 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0, 415*89c4ff92SAndroid Build Coastguard Worker 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 }; 416*89c4ff92SAndroid Build Coastguard Worker float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f }; 417*89c4ff92SAndroid Build Coastguard Worker 418*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<float[]> result = ToFloatArray(data, info); 419*89c4ff92SAndroid Build Coastguard Worker 420*89c4ff92SAndroid Build Coastguard Worker for (uint i = 0; i < info.GetNumElements(); ++i) 421*89c4ff92SAndroid Build Coastguard Worker { 422*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(result[i], doctest::Approx(expected[i])); 423*89c4ff92SAndroid Build Coastguard Worker } 424*89c4ff92SAndroid Build Coastguard Worker } 425*89c4ff92SAndroid Build Coastguard Worker } 426