xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Unpack.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Unpack")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct UnpackFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
UnpackFixtureUnpackFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit UnpackFixture(const std::string& inputShape,
14*89c4ff92SAndroid Build Coastguard Worker                            const unsigned int numberOfOutputs,
15*89c4ff92SAndroid Build Coastguard Worker                            const std::string& outputShape,
16*89c4ff92SAndroid Build Coastguard Worker                            const std::string& axis,
17*89c4ff92SAndroid Build Coastguard Worker                            const std::string& num,
18*89c4ff92SAndroid Build Coastguard Worker                            const std::string& dataType,
19*89c4ff92SAndroid Build Coastguard Worker                            const std::string& outputScale,
20*89c4ff92SAndroid Build Coastguard Worker                            const std::string& outputOffset)
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         // As input index is 0, output indexes start at 1
23*89c4ff92SAndroid Build Coastguard Worker         std::string outputIndexes = "1";
24*89c4ff92SAndroid Build Coastguard Worker         for(unsigned int i = 1; i < numberOfOutputs; i++)
25*89c4ff92SAndroid Build Coastguard Worker         {
26*89c4ff92SAndroid Build Coastguard Worker             outputIndexes += ", " + std::to_string(i+1);
27*89c4ff92SAndroid Build Coastguard Worker         }
28*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
29*89c4ff92SAndroid Build Coastguard Worker             {
30*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
31*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "UNPACK" } ],
32*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
33*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
34*89c4ff92SAndroid Build Coastguard Worker                         {
35*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputShape + R"(,
36*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + dataType + R"(,
37*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
38*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
39*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
40*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
41*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
42*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
43*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
44*89c4ff92SAndroid Build Coastguard Worker                             }
45*89c4ff92SAndroid Build Coastguard Worker                         },)";
46*89c4ff92SAndroid Build Coastguard Worker         // Append the required number of outputs for this UnpackFixture.
47*89c4ff92SAndroid Build Coastguard Worker         // As input index is 0, output indexes start at 1.
48*89c4ff92SAndroid Build Coastguard Worker         for(unsigned int i = 0; i < numberOfOutputs; i++)
49*89c4ff92SAndroid Build Coastguard Worker         {
50*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"(
51*89c4ff92SAndroid Build Coastguard Worker                         {
52*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + outputShape + R"( ,
53*89c4ff92SAndroid Build Coastguard Worker                                 "type": )" + dataType + R"(,
54*89c4ff92SAndroid Build Coastguard Worker                                 "buffer": )" + std::to_string(i + 1) + R"(,
55*89c4ff92SAndroid Build Coastguard Worker                                 "name": "outputTensor)" + std::to_string(i + 1) + R"(",
56*89c4ff92SAndroid Build Coastguard Worker                                 "quantization": {
57*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
58*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
59*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ )" + outputScale + R"( ],
60*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ )" + outputOffset + R"( ],
61*89c4ff92SAndroid Build Coastguard Worker                             }
62*89c4ff92SAndroid Build Coastguard Worker                         },)";
63*89c4ff92SAndroid Build Coastguard Worker         }
64*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
65*89c4ff92SAndroid Build Coastguard Worker                     ],
66*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
67*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ )" + outputIndexes + R"( ],
68*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
69*89c4ff92SAndroid Build Coastguard Worker                         {
70*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
71*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0 ],
72*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ )" + outputIndexes + R"( ],
73*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "UnpackOptions",
74*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
75*89c4ff92SAndroid Build Coastguard Worker                                 "axis": )" + axis;
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker                     if(!num.empty())
78*89c4ff92SAndroid Build Coastguard Worker                     {
79*89c4ff92SAndroid Build Coastguard Worker                         m_JsonString += R"(,
80*89c4ff92SAndroid Build Coastguard Worker                                 "num" : )" + num;
81*89c4ff92SAndroid Build Coastguard Worker                     }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker                     m_JsonString += R"(
84*89c4ff92SAndroid Build Coastguard Worker                             },
85*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
86*89c4ff92SAndroid Build Coastguard Worker                         }
87*89c4ff92SAndroid Build Coastguard Worker                     ],
88*89c4ff92SAndroid Build Coastguard Worker                 } ],
89*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
90*89c4ff92SAndroid Build Coastguard Worker                     { },
91*89c4ff92SAndroid Build Coastguard Worker                     { }
92*89c4ff92SAndroid Build Coastguard Worker                 ]
93*89c4ff92SAndroid Build Coastguard Worker             }
94*89c4ff92SAndroid Build Coastguard Worker         )";
95*89c4ff92SAndroid Build Coastguard Worker         Setup();
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker };
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker struct DefaultUnpackAxisZeroFixture : UnpackFixture
100*89c4ff92SAndroid Build Coastguard Worker {
DefaultUnpackAxisZeroFixtureDefaultUnpackAxisZeroFixture101*89c4ff92SAndroid Build Coastguard Worker     DefaultUnpackAxisZeroFixture() : UnpackFixture("[ 4, 1, 6 ]", 4, "[ 1, 6 ]", "0", "", "FLOAT32", "1.0", "0") {}
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker struct DefaultUnpackAxisZeroUint8Fixture : UnpackFixture
105*89c4ff92SAndroid Build Coastguard Worker {
DefaultUnpackAxisZeroUint8FixtureDefaultUnpackAxisZeroUint8Fixture106*89c4ff92SAndroid Build Coastguard Worker     DefaultUnpackAxisZeroUint8Fixture() : UnpackFixture("[ 4, 1, 6 ]", 4, "[ 1, 6 ]", "0", "", "UINT8", "0.1", "0") {}
107*89c4ff92SAndroid Build Coastguard Worker };
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DefaultUnpackAxisZeroFixture, "UnpackAxisZeroNumIsDefaultNotSpecified")
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(
112*89c4ff92SAndroid Build Coastguard Worker         0,
113*89c4ff92SAndroid Build Coastguard Worker         { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
114*89c4ff92SAndroid Build Coastguard Worker                             7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
115*89c4ff92SAndroid Build Coastguard Worker                             13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
116*89c4ff92SAndroid Build Coastguard Worker                             19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f } } },
117*89c4ff92SAndroid Build Coastguard Worker         { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
118*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor2", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
119*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor3", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }},
120*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor4", { 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }} });
121*89c4ff92SAndroid Build Coastguard Worker }
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DefaultUnpackAxisZeroUint8Fixture, "UnpackAxisZeroNumIsDefaultNotSpecifiedUint8")
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
126*89c4ff92SAndroid Build Coastguard Worker         0,
127*89c4ff92SAndroid Build Coastguard Worker         { {"inputTensor", { 1, 2, 3, 4, 5, 6,
128*89c4ff92SAndroid Build Coastguard Worker                             7, 8, 9, 10, 11, 12,
129*89c4ff92SAndroid Build Coastguard Worker                             13, 14, 15, 16, 17, 18,
130*89c4ff92SAndroid Build Coastguard Worker                             19, 20, 21, 22, 23, 24 } } },
131*89c4ff92SAndroid Build Coastguard Worker         { {"outputTensor1", { 10, 20, 30, 40, 50, 60 }},
132*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor2", { 70, 80, 90, 100, 110, 120 }},
133*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor3", { 130, 140, 150, 160, 170, 180 }},
134*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor4", { 190, 200, 210, 220, 230, 240 }} });
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker struct DefaultUnpackLastAxisFixture : UnpackFixture
138*89c4ff92SAndroid Build Coastguard Worker {
DefaultUnpackLastAxisFixtureDefaultUnpackLastAxisFixture139*89c4ff92SAndroid Build Coastguard Worker     DefaultUnpackLastAxisFixture() : UnpackFixture("[ 4, 1, 6 ]", 6, "[ 4, 1 ]", "2", "6", "FLOAT32", "1.0", "0") {}
140*89c4ff92SAndroid Build Coastguard Worker };
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker struct DefaultUnpackLastAxisUint8Fixture : UnpackFixture
143*89c4ff92SAndroid Build Coastguard Worker {
DefaultUnpackLastAxisUint8FixtureDefaultUnpackLastAxisUint8Fixture144*89c4ff92SAndroid Build Coastguard Worker     DefaultUnpackLastAxisUint8Fixture() : UnpackFixture("[ 4, 1, 6 ]", 6, "[ 4, 1 ]", "2", "6", "UINT8", "0.1", "0") {}
145*89c4ff92SAndroid Build Coastguard Worker };
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DefaultUnpackLastAxisFixture, "UnpackLastAxisNumSix")
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::Float32>(
150*89c4ff92SAndroid Build Coastguard Worker         0,
151*89c4ff92SAndroid Build Coastguard Worker         { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
152*89c4ff92SAndroid Build Coastguard Worker                             7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
153*89c4ff92SAndroid Build Coastguard Worker                             13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
154*89c4ff92SAndroid Build Coastguard Worker                             19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f } } },
155*89c4ff92SAndroid Build Coastguard Worker         { {"outputTensor1", { 1.0f, 7.0f, 13.0f, 19.0f }},
156*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor2", { 2.0f, 8.0f, 14.0f, 20.0f }},
157*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor3", { 3.0f, 9.0f, 15.0f, 21.0f }},
158*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor4", { 4.0f, 10.0f, 16.0f, 22.0f }},
159*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor5", { 5.0f, 11.0f, 17.0f, 23.0f }},
160*89c4ff92SAndroid Build Coastguard Worker           {"outputTensor6", { 6.0f, 12.0f, 18.0f, 24.0f }} });
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DefaultUnpackLastAxisUint8Fixture, "UnpackLastAxisNumSixUint8") {
164*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
165*89c4ff92SAndroid Build Coastguard Worker         0,
166*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor", { 1, 2, 3, 4, 5, 6,
167*89c4ff92SAndroid Build Coastguard Worker                            7, 8, 9, 10, 11, 12,
168*89c4ff92SAndroid Build Coastguard Worker                            13, 14, 15, 16, 17, 18,
169*89c4ff92SAndroid Build Coastguard Worker                            19, 20, 21, 22, 23, 24 }}},
170*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor1", { 10, 70, 130, 190 }},
171*89c4ff92SAndroid Build Coastguard Worker          {"outputTensor2", { 20, 80, 140, 200 }},
172*89c4ff92SAndroid Build Coastguard Worker          {"outputTensor3", { 30, 90, 150, 210 }},
173*89c4ff92SAndroid Build Coastguard Worker          {"outputTensor4", { 40, 100, 160, 220 }},
174*89c4ff92SAndroid Build Coastguard Worker          {"outputTensor5", { 50, 110, 170, 230 }},
175*89c4ff92SAndroid Build Coastguard Worker          {"outputTensor6", { 60, 120, 180, 240 }}});
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker }
179