1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "../ParserHelper.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 15*89c4ff92SAndroid Build Coastguard Worker using namespace armnnUtils; 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("ParserHelperSuite") 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("CalculateReducedOutputTensoInfoTest") 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker bool keepDims = false; 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker unsigned int inputShape[] = { 2, 3, 4 }; 24*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(3, &inputShape[0], DataType::Float32); 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker // Reducing all dimensions results in one single output value (one dimension) 27*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> axisData1 = { 0, 1, 2 }; 28*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo1; 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker CalculateReducedOutputTensoInfo(inputTensorInfo, axisData1, keepDims, outputTensorInfo1); 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo1.GetNumDimensions() == 1); 33*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo1.GetShape()[0] == 1); 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker // Reducing dimension 0 results in a 3x4 size tensor (one dimension) 36*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> axisData2 = { 0 }; 37*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo2; 38*89c4ff92SAndroid Build Coastguard Worker 39*89c4ff92SAndroid Build Coastguard Worker CalculateReducedOutputTensoInfo(inputTensorInfo, axisData2, keepDims, outputTensorInfo2); 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo2.GetNumDimensions() == 1); 42*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo2.GetShape()[0] == 12); 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker // Reducing dimensions 0,1 results in a 4 size tensor (one dimension) 45*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> axisData3 = { 0, 1 }; 46*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo3; 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker CalculateReducedOutputTensoInfo(inputTensorInfo, axisData3, keepDims, outputTensorInfo3); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo3.GetNumDimensions() == 1); 51*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo3.GetShape()[0] == 4); 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker // Reducing dimension 0 results in a { 1, 3, 4 } dimension tensor 54*89c4ff92SAndroid Build Coastguard Worker keepDims = true; 55*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> axisData4 = { 0 }; 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo4; 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker CalculateReducedOutputTensoInfo(inputTensorInfo, axisData4, keepDims, outputTensorInfo4); 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo4.GetNumDimensions() == 3); 62*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo4.GetShape()[0] == 1); 63*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo4.GetShape()[1] == 3); 64*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo4.GetShape()[2] == 4); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker // Reducing dimension 1, 2 results in a { 2, 1, 1 } dimension tensor 67*89c4ff92SAndroid Build Coastguard Worker keepDims = true; 68*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> axisData5 = { 1, 2 }; 69*89c4ff92SAndroid Build Coastguard Worker 70*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo5; 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker CalculateReducedOutputTensoInfo(inputTensorInfo, axisData5, keepDims, outputTensorInfo5); 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo5.GetNumDimensions() == 3); 75*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo5.GetShape()[0] == 2); 76*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo5.GetShape()[1] == 1); 77*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo5.GetShape()[2] == 1); 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker 83