xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Squeeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Squeeze")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureSqueezeFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit SqueezeFixture(const std::string& inputShape,
14*89c4ff92SAndroid Build Coastguard Worker                             const std::string& outputShape,
15*89c4ff92SAndroid Build Coastguard Worker                             const std::string& squeezeDims)
16*89c4ff92SAndroid Build Coastguard Worker     {
17*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
18*89c4ff92SAndroid Build Coastguard Worker             {
19*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
20*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "SQUEEZE" } ],
21*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
22*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
23*89c4ff92SAndroid Build Coastguard Worker                         {)";
24*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
25*89c4ff92SAndroid Build Coastguard Worker                             "shape" : )" + inputShape + ",";
26*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
27*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
28*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
29*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
30*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
31*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
35*89c4ff92SAndroid Build Coastguard Worker                             }
36*89c4ff92SAndroid Build Coastguard Worker                         },
37*89c4ff92SAndroid Build Coastguard Worker                         {)";
38*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
39*89c4ff92SAndroid Build Coastguard Worker                             "shape" : )" + outputShape;
40*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,
41*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
42*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
43*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
44*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
45*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
46*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
47*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
48*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
49*89c4ff92SAndroid Build Coastguard Worker                             }
50*89c4ff92SAndroid Build Coastguard Worker                         }
51*89c4ff92SAndroid Build Coastguard Worker                     ],
52*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
53*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
54*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
55*89c4ff92SAndroid Build Coastguard Worker                         {
56*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
57*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0 ],
58*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
59*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "SqueezeOptions",
60*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {)";
61*89c4ff92SAndroid Build Coastguard Worker         if (!squeezeDims.empty())
62*89c4ff92SAndroid Build Coastguard Worker         {
63*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"("squeeze_dims" : )" + squeezeDims;
64*89c4ff92SAndroid Build Coastguard Worker         }
65*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(},
66*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
67*89c4ff92SAndroid Build Coastguard Worker                         }
68*89c4ff92SAndroid Build Coastguard Worker                     ],
69*89c4ff92SAndroid Build Coastguard Worker                 } ],
70*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [ {}, {} ]
71*89c4ff92SAndroid Build Coastguard Worker             }
72*89c4ff92SAndroid Build Coastguard Worker         )";
73*89c4ff92SAndroid Build Coastguard Worker     }
74*89c4ff92SAndroid Build Coastguard Worker };
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithSqueezeDims : SqueezeFixture
77*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithSqueezeDimsSqueezeFixtureWithSqueezeDims78*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2, 1 ]", "[ 0, 1, 2 ]") {}
79*89c4ff92SAndroid Build Coastguard Worker };
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithSqueezeDims, "ParseSqueezeWithSqueezeDims")
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
84*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
85*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
86*89c4ff92SAndroid Build Coastguard Worker         == armnn::TensorShape({2,2,1})));
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithoutSqueezeDims : SqueezeFixture
91*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithoutSqueezeDimsSqueezeFixtureWithoutSqueezeDims92*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithoutSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2 ]", "") {}
93*89c4ff92SAndroid Build Coastguard Worker };
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithoutSqueezeDims, "ParseSqueezeWithoutSqueezeDims")
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
98*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
99*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
100*89c4ff92SAndroid Build Coastguard Worker         == armnn::TensorShape({2,2})));
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithInvalidInput : SqueezeFixture
104*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithInvalidInputSqueezeFixtureWithInvalidInput105*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithInvalidInput() : SqueezeFixture("[ 1, 2, 2, 1, 2, 2 ]", "[ 1, 2, 2, 1, 2 ]", "[ ]") {}
106*89c4ff92SAndroid Build Coastguard Worker };
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithInvalidInput, "ParseSqueezeInvalidInput")
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     static_assert(armnn::MaxNumOfTensorDimensions == 5, "Please update SqueezeFixtureWithInvalidInput");
111*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")),
112*89c4ff92SAndroid Build Coastguard Worker                       armnn::InvalidArgumentException);
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithSqueezeDimsSizeInvalid : SqueezeFixture
116*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithSqueezeDimsSizeInvalidSqueezeFixtureWithSqueezeDimsSizeInvalid117*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithSqueezeDimsSizeInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
118*89c4ff92SAndroid Build Coastguard Worker                                                                 "[ 1, 2, 2, 1 ]",
119*89c4ff92SAndroid Build Coastguard Worker                                                                 "[ 1, 2, 2, 2, 2 ]") {}
120*89c4ff92SAndroid Build Coastguard Worker };
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithSqueezeDimsSizeInvalid, "ParseSqueezeInvalidSqueezeDims")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithNegativeSqueezeDims1 : SqueezeFixture
129*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithNegativeSqueezeDims1SqueezeFixtureWithNegativeSqueezeDims1130*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithNegativeSqueezeDims1() : SqueezeFixture("[ 1, 2, 2, 1 ]",
131*89c4ff92SAndroid Build Coastguard Worker                                                              "[ 2, 2, 1 ]",
132*89c4ff92SAndroid Build Coastguard Worker                                                              "[ -1 ]") {}
133*89c4ff92SAndroid Build Coastguard Worker };
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims1, "ParseSqueezeNegativeSqueezeDims1")
136*89c4ff92SAndroid Build Coastguard Worker {
137*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
138*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
139*89c4ff92SAndroid Build Coastguard Worker             CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
140*89c4ff92SAndroid Build Coastguard Worker                    == armnn::TensorShape({ 2, 2, 1 })));
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithNegativeSqueezeDims2 : SqueezeFixture
144*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithNegativeSqueezeDims2SqueezeFixtureWithNegativeSqueezeDims2145*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithNegativeSqueezeDims2() : SqueezeFixture("[ 1, 2, 2, 1 ]",
146*89c4ff92SAndroid Build Coastguard Worker                                                               "[ 1, 2, 2 ]",
147*89c4ff92SAndroid Build Coastguard Worker                                                               "[ -1 ]") {}
148*89c4ff92SAndroid Build Coastguard Worker };
149*89c4ff92SAndroid Build Coastguard Worker 
150*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDims2, "ParseSqueezeNegativeSqueezeDims2")
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
153*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
154*89c4ff92SAndroid Build Coastguard Worker             CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
155*89c4ff92SAndroid Build Coastguard Worker                    == armnn::TensorShape({ 1, 2, 2 })));
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker struct SqueezeFixtureWithNegativeSqueezeDimsInvalid : SqueezeFixture
159*89c4ff92SAndroid Build Coastguard Worker {
SqueezeFixtureWithNegativeSqueezeDimsInvalidSqueezeFixtureWithNegativeSqueezeDimsInvalid160*89c4ff92SAndroid Build Coastguard Worker     SqueezeFixtureWithNegativeSqueezeDimsInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
161*89c4ff92SAndroid Build Coastguard Worker                                                                     "[ 1, 2, 2, 1 ]",
162*89c4ff92SAndroid Build Coastguard Worker                                                                     "[ -2 , 2 ]") {}
163*89c4ff92SAndroid Build Coastguard Worker };
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SqueezeFixtureWithNegativeSqueezeDimsInvalid, "ParseSqueezeNegativeSqueezeDimsInvalid")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker }
172