xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Squeeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Squeeze")
10 {
11 struct SqueezeFixture : public ParserFlatbuffersFixture
12 {
SqueezeFixtureSqueezeFixture13     explicit SqueezeFixture(const std::string& inputShape,
14                             const std::string& outputShape,
15                             const std::string& squeezeDims)
16     {
17         m_JsonString = R"(
18             {
19                 "version": 3,
20                 "operator_codes": [ { "builtin_code": "SQUEEZE" } ],
21                 "subgraphs": [ {
22                     "tensors": [
23                         {)";
24         m_JsonString += R"(
25                             "shape" : )" + inputShape + ",";
26         m_JsonString += R"(
27                             "type": "UINT8",
28                             "buffer": 0,
29                             "name": "inputTensor",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ],
35                             }
36                         },
37                         {)";
38         m_JsonString += R"(
39                             "shape" : )" + outputShape;
40         m_JsonString += R"(,
41                             "type": "UINT8",
42                             "buffer": 1,
43                             "name": "outputTensor",
44                             "quantization": {
45                                 "min": [ 0.0 ],
46                                 "max": [ 255.0 ],
47                                 "scale": [ 1.0 ],
48                                 "zero_point": [ 0 ],
49                             }
50                         }
51                     ],
52                     "inputs": [ 0 ],
53                     "outputs": [ 1 ],
54                     "operators": [
55                         {
56                             "opcode_index": 0,
57                             "inputs": [ 0 ],
58                             "outputs": [ 1 ],
59                             "builtin_options_type": "SqueezeOptions",
60                             "builtin_options": {)";
61         if (!squeezeDims.empty())
62         {
63             m_JsonString += R"("squeeze_dims" : )" + squeezeDims;
64         }
65         m_JsonString += R"(},
66                             "custom_options_format": "FLEXBUFFERS"
67                         }
68                     ],
69                 } ],
70                 "buffers" : [ {}, {} ]
71             }
72         )";
73     }
74 };
75 
76 struct SqueezeFixtureWithSqueezeDims : SqueezeFixture
77 {
SqueezeFixtureWithSqueezeDimsSqueezeFixtureWithSqueezeDims78     SqueezeFixtureWithSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2, 1 ]", "[ 0, 1, 2 ]") {}
79 };
80 
81 TEST_CASE_FIXTURE(SqueezeFixtureWithSqueezeDims, "ParseSqueezeWithSqueezeDims")
82 {
83     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
84     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
85     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
86         == armnn::TensorShape({2,2,1})));
87 
88 }
89 
90 struct SqueezeFixtureWithoutSqueezeDims : SqueezeFixture
91 {
SqueezeFixtureWithoutSqueezeDimsSqueezeFixtureWithoutSqueezeDims92     SqueezeFixtureWithoutSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2 ]", "") {}
93 };
94 
95 TEST_CASE_FIXTURE(SqueezeFixtureWithoutSqueezeDims, "ParseSqueezeWithoutSqueezeDims")
96 {
97     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
98     RunTest<2, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
99     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
100         == armnn::TensorShape({2,2})));
101 }
102 
103 struct SqueezeFixtureWithInvalidInput : SqueezeFixture
104 {
SqueezeFixtureWithInvalidInputSqueezeFixtureWithInvalidInput105     SqueezeFixtureWithInvalidInput() : SqueezeFixture("[ 1, 2, 2, 1, 2, 2 ]", "[ 1, 2, 2, 1, 2 ]", "[ ]") {}
106 };
107 
108 TEST_CASE_FIXTURE(SqueezeFixtureWithInvalidInput, "ParseSqueezeInvalidInput")
109 {
110     static_assert(armnn::MaxNumOfTensorDimensions == 5, "Please update SqueezeFixtureWithInvalidInput");
111     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")),
112                       armnn::InvalidArgumentException);
113 }
114 
115 struct SqueezeFixtureWithSqueezeDimsSizeInvalid : SqueezeFixture
116 {
SqueezeFixtureWithSqueezeDimsSizeInvalidSqueezeFixtureWithSqueezeDimsSizeInvalid117     SqueezeFixtureWithSqueezeDimsSizeInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
118                                                                 "[ 1, 2, 2, 1 ]",
119                                                                 "[ 1, 2, 2, 2, 2 ]") {}
120 };
121 
122 TEST_CASE_FIXTURE(SqueezeFixtureWithSqueezeDimsSizeInvalid, "ParseSqueezeInvalidSqueezeDims")
123 {
124     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
125 }
126 
127 
128 struct SqueezeFixtureWithNegativeSqueezeDims1 : SqueezeFixture
129 {
SqueezeFixtureWithNegativeSqueezeDims1SqueezeFixtureWithNegativeSqueezeDims1130     SqueezeFixtureWithNegativeSqueezeDims1() : SqueezeFixture("[ 1, 2, 2, 1 ]",
131                                                              "[ 2, 2, 1 ]",
132                                                              "[ -1 ]") {}
133 };
134 
135 TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims1, "ParseSqueezeNegativeSqueezeDims1")
136 {
137     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
138     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
139             CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
140                    == armnn::TensorShape({ 2, 2, 1 })));
141 }
142 
143 struct SqueezeFixtureWithNegativeSqueezeDims2 : SqueezeFixture
144 {
SqueezeFixtureWithNegativeSqueezeDims2SqueezeFixtureWithNegativeSqueezeDims2145     SqueezeFixtureWithNegativeSqueezeDims2() : SqueezeFixture("[ 1, 2, 2, 1 ]",
146                                                               "[ 1, 2, 2 ]",
147                                                               "[ -1 ]") {}
148 };
149 
150 TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims2, "ParseSqueezeNegativeSqueezeDims2")
151 {
152     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
153     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
154             CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
155                    == armnn::TensorShape({ 1, 2, 2 })));
156 }
157 
158 struct SqueezeFixtureWithNegativeSqueezeDimsInvalid : SqueezeFixture
159 {
SqueezeFixtureWithNegativeSqueezeDimsInvalidSqueezeFixtureWithNegativeSqueezeDimsInvalid160     SqueezeFixtureWithNegativeSqueezeDimsInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
161                                                                     "[ 1, 2, 2, 1 ]",
162                                                                     "[ -2 , 2 ]") {}
163 };
164 
165 TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDimsInvalid, "ParseSqueezeNegativeSqueezeDimsInvalid")
166 {
167     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
168 }
169 
170 
171 }
172