xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "../TfLiteParser.hpp"
7 
8 #include <doctest/doctest.h>
9 
10 TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
11 {
12 
13 struct TfLiteParserFixture
14 {
15 
16     armnnTfLiteParser::TfLiteParserImpl m_Parser;
17     unsigned int m_InputShape[4];
18 
TfLiteParserFixtureTfLiteParserFixture19     TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
~TfLiteParserFixtureTfLiteParserFixture20     ~TfLiteParserFixture()          {  }
21 
22 };
23 
24 TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
25 {
26 
27     std::vector<uint32_t> squeezeDims = {  };
28 
29     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
30     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
31     CHECK(outputTensorInfo.GetNumElements() == 4);
32     CHECK(outputTensorInfo.GetNumDimensions() == 2);
33     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
34 };
35 
36 TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
37 {
38     std::vector<uint32_t> squeezeDims = { 1, 2 };
39 
40     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
41     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
42     CHECK(outputTensorInfo.GetNumElements() == 4);
43     CHECK(outputTensorInfo.GetNumDimensions() == 4);
44     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
45 };
46 
47 TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
48 {
49     std::vector<uint32_t> squeezeDims = { 1, 3 };
50 
51     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
52     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
53     CHECK(outputTensorInfo.GetNumElements() == 4);
54     CHECK(outputTensorInfo.GetNumDimensions() == 3);
55     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
56 };
57 
58 }