xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Concatenation.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Concatenation")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureConcatenationFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit ConcatenationFixture(const std::string & inputShape1,
14*89c4ff92SAndroid Build Coastguard Worker                                   const std::string & inputShape2,
15*89c4ff92SAndroid Build Coastguard Worker                                   const std::string & outputShape,
16*89c4ff92SAndroid Build Coastguard Worker                                   const std::string & axis,
17*89c4ff92SAndroid Build Coastguard Worker                                   const std::string & activation="NONE")
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
20*89c4ff92SAndroid Build Coastguard Worker             {
21*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
22*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "CONCATENATION" } ],
23*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
24*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
25*89c4ff92SAndroid Build Coastguard Worker                         {
26*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputShape1 + R"(,
27*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
28*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
29*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor1",
30*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
31*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
35*89c4ff92SAndroid Build Coastguard Worker                             }
36*89c4ff92SAndroid Build Coastguard Worker                         },
37*89c4ff92SAndroid Build Coastguard Worker                         {
38*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputShape2 + R"(,
39*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
40*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
41*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor2",
42*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
43*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
44*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
45*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
46*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
47*89c4ff92SAndroid Build Coastguard Worker                             }
48*89c4ff92SAndroid Build Coastguard Worker                         },
49*89c4ff92SAndroid Build Coastguard Worker                         {
50*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + outputShape + R"( ,
51*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
52*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
53*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
54*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
55*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
56*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
57*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
58*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
59*89c4ff92SAndroid Build Coastguard Worker                             }
60*89c4ff92SAndroid Build Coastguard Worker                         }
61*89c4ff92SAndroid Build Coastguard Worker                     ],
62*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0, 1 ],
63*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 2 ],
64*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
65*89c4ff92SAndroid Build Coastguard Worker                         {
66*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
67*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0, 1 ],
68*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 2 ],
69*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "ConcatenationOptions",
70*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
71*89c4ff92SAndroid Build Coastguard Worker                                 "axis": )" + axis + R"(,
72*89c4ff92SAndroid Build Coastguard Worker                                 "fused_activation_function": )" + activation + R"(
73*89c4ff92SAndroid Build Coastguard Worker                             },
74*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
75*89c4ff92SAndroid Build Coastguard Worker                         }
76*89c4ff92SAndroid Build Coastguard Worker                     ],
77*89c4ff92SAndroid Build Coastguard Worker                 } ],
78*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
79*89c4ff92SAndroid Build Coastguard Worker                     { },
80*89c4ff92SAndroid Build Coastguard Worker                     { }
81*89c4ff92SAndroid Build Coastguard Worker                 ]
82*89c4ff92SAndroid Build Coastguard Worker             }
83*89c4ff92SAndroid Build Coastguard Worker         )";
84*89c4ff92SAndroid Build Coastguard Worker         Setup();
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker };
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixtureNegativeDim : ConcatenationFixture
90*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureNegativeDimConcatenationFixtureNegativeDim91*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixtureNegativeDim() : ConcatenationFixture("[ 1, 1, 2, 2 ]",
92*89c4ff92SAndroid Build Coastguard Worker                                                              "[ 1, 1, 2, 2 ]",
93*89c4ff92SAndroid Build Coastguard Worker                                                              "[ 1, 2, 2, 2 ]",
94*89c4ff92SAndroid Build Coastguard Worker                                                              "-3" ) {}
95*89c4ff92SAndroid Build Coastguard Worker };
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixtureNegativeDim, "ParseConcatenationNegativeDim")
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
100*89c4ff92SAndroid Build Coastguard Worker         0,
101*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor1", { 0, 1, 2, 3 }},
102*89c4ff92SAndroid Build Coastguard Worker         {"inputTensor2", { 4, 5, 6, 7 }}},
103*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixtureNCHW : ConcatenationFixture
107*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureNCHWConcatenationFixtureNCHW108*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixtureNCHW() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 2, 2, 2 ]", "1" ) {}
109*89c4ff92SAndroid Build Coastguard Worker };
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixtureNCHW, "ParseConcatenationNCHW")
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
114*89c4ff92SAndroid Build Coastguard Worker         0,
115*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor1", { 0, 1, 2, 3 }},
116*89c4ff92SAndroid Build Coastguard Worker         {"inputTensor2", { 4, 5, 6, 7 }}},
117*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixtureNHWC : ConcatenationFixture
121*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureNHWCConcatenationFixtureNHWC122*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixtureNHWC() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 4 ]", "3" ) {}
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixtureNHWC, "ParseConcatenationNHWC")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
128*89c4ff92SAndroid Build Coastguard Worker         0,
129*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor1", { 0, 1, 2, 3 }},
130*89c4ff92SAndroid Build Coastguard Worker         {"inputTensor2", { 4, 5, 6, 7 }}},
131*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 0, 1, 4, 5, 2, 3, 6, 7 }}});
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixtureDim1 : ConcatenationFixture
135*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureDim1ConcatenationFixtureDim1136*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixtureDim1() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 4, 3, 4 ]", "1" ) {}
137*89c4ff92SAndroid Build Coastguard Worker };
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixtureDim1, "ParseConcatenationDim1")
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
142*89c4ff92SAndroid Build Coastguard Worker         0,
143*89c4ff92SAndroid Build Coastguard Worker         { { "inputTensor1", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
144*89c4ff92SAndroid Build Coastguard Worker                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 } },
145*89c4ff92SAndroid Build Coastguard Worker         { "inputTensor2", {  50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
146*89c4ff92SAndroid Build Coastguard Worker                              62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } },
147*89c4ff92SAndroid Build Coastguard Worker         { { "outputTensor", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
148*89c4ff92SAndroid Build Coastguard Worker                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
149*89c4ff92SAndroid Build Coastguard Worker                                50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
150*89c4ff92SAndroid Build Coastguard Worker                                62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } });
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixtureDim3 : ConcatenationFixture
154*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixtureDim3ConcatenationFixtureDim3155*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixtureDim3() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 8 ]", "3" ) {}
156*89c4ff92SAndroid Build Coastguard Worker };
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixtureDim3, "ParseConcatenationDim3")
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
161*89c4ff92SAndroid Build Coastguard Worker         0,
162*89c4ff92SAndroid Build Coastguard Worker         { { "inputTensor1", {  0,  1,  2,  3,
163*89c4ff92SAndroid Build Coastguard Worker                                4,  5,  6,  7,
164*89c4ff92SAndroid Build Coastguard Worker                                8,  9, 10, 11,
165*89c4ff92SAndroid Build Coastguard Worker                                12, 13, 14, 15,
166*89c4ff92SAndroid Build Coastguard Worker                                16, 17, 18, 19,
167*89c4ff92SAndroid Build Coastguard Worker                                20, 21, 22, 23 } },
168*89c4ff92SAndroid Build Coastguard Worker         { "inputTensor2", {  50, 51, 52, 53,
169*89c4ff92SAndroid Build Coastguard Worker                              54, 55, 56, 57,
170*89c4ff92SAndroid Build Coastguard Worker                              58, 59, 60, 61,
171*89c4ff92SAndroid Build Coastguard Worker                              62, 63, 64, 65,
172*89c4ff92SAndroid Build Coastguard Worker                              66, 67, 68, 69,
173*89c4ff92SAndroid Build Coastguard Worker                              70, 71, 72, 73 } } },
174*89c4ff92SAndroid Build Coastguard Worker         { { "outputTensor", {  0,  1,  2,  3,
175*89c4ff92SAndroid Build Coastguard Worker                                50, 51, 52, 53,
176*89c4ff92SAndroid Build Coastguard Worker                                4,  5,  6,  7,
177*89c4ff92SAndroid Build Coastguard Worker                                54, 55, 56, 57,
178*89c4ff92SAndroid Build Coastguard Worker                                8,  9,  10, 11,
179*89c4ff92SAndroid Build Coastguard Worker                                58, 59, 60, 61,
180*89c4ff92SAndroid Build Coastguard Worker                                12, 13, 14, 15,
181*89c4ff92SAndroid Build Coastguard Worker                                62, 63, 64, 65,
182*89c4ff92SAndroid Build Coastguard Worker                                16, 17, 18, 19,
183*89c4ff92SAndroid Build Coastguard Worker                                66, 67, 68, 69,
184*89c4ff92SAndroid Build Coastguard Worker                                20, 21, 22, 23,
185*89c4ff92SAndroid Build Coastguard Worker                                70, 71, 72, 73 } } });
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixture3DDim0 : ConcatenationFixture
189*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixture3DDim0ConcatenationFixture3DDim0190*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixture3DDim0() : ConcatenationFixture("[ 1, 2, 3]", "[ 2, 2, 3]", "[ 3, 2, 3]", "0" ) {}
191*89c4ff92SAndroid Build Coastguard Worker };
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixture3DDim0, "ParseConcatenation3DDim0")
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(
196*89c4ff92SAndroid Build Coastguard Worker         0,
197*89c4ff92SAndroid Build Coastguard Worker         { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
198*89c4ff92SAndroid Build Coastguard Worker           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
199*89c4ff92SAndroid Build Coastguard Worker                              12, 13, 14, 15, 16, 17 } } },
200*89c4ff92SAndroid Build Coastguard Worker         { { "outputTensor", { 0,  1,  2,  3,  4,  5,
201*89c4ff92SAndroid Build Coastguard Worker                               6,  7,  8,  9, 10, 11,
202*89c4ff92SAndroid Build Coastguard Worker                              12, 13, 14, 15, 16, 17 } } });
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixture3DDim1 : ConcatenationFixture
206*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixture3DDim1ConcatenationFixture3DDim1207*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixture3DDim1() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 4, 3]", "[ 1, 6, 3]", "1" ) {}
208*89c4ff92SAndroid Build Coastguard Worker };
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixture3DDim1, "ParseConcatenation3DDim1")
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(
213*89c4ff92SAndroid Build Coastguard Worker         0,
214*89c4ff92SAndroid Build Coastguard Worker         { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
215*89c4ff92SAndroid Build Coastguard Worker           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
216*89c4ff92SAndroid Build Coastguard Worker                              12, 13, 14, 15, 16, 17 } } },
217*89c4ff92SAndroid Build Coastguard Worker         { { "outputTensor", { 0,  1,  2,  3,  4,  5,
218*89c4ff92SAndroid Build Coastguard Worker                               6,  7,  8,  9, 10, 11,
219*89c4ff92SAndroid Build Coastguard Worker                              12, 13, 14, 15, 16, 17 } } });
220*89c4ff92SAndroid Build Coastguard Worker }
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker struct ConcatenationFixture3DDim2 : ConcatenationFixture
223*89c4ff92SAndroid Build Coastguard Worker {
ConcatenationFixture3DDim2ConcatenationFixture3DDim2224*89c4ff92SAndroid Build Coastguard Worker     ConcatenationFixture3DDim2() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 2, 6]", "[ 1, 2, 9]", "2" ) {}
225*89c4ff92SAndroid Build Coastguard Worker };
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ConcatenationFixture3DDim2, "ParseConcatenation3DDim2")
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(
230*89c4ff92SAndroid Build Coastguard Worker         0,
231*89c4ff92SAndroid Build Coastguard Worker         { { "inputTensor1", { 0,  1,  2,
232*89c4ff92SAndroid Build Coastguard Worker                               3,  4,  5 } },
233*89c4ff92SAndroid Build Coastguard Worker           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
234*89c4ff92SAndroid Build Coastguard Worker                              12, 13, 14, 15, 16, 17 } } },
235*89c4ff92SAndroid Build Coastguard Worker         { { "outputTensor", { 0,  1,  2,  6,  7,  8,  9, 10, 11,
236*89c4ff92SAndroid Build Coastguard Worker                               3,  4,  5, 12, 13, 14, 15, 16, 17 } } });
237*89c4ff92SAndroid Build Coastguard Worker }
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker }
240