xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/FullyConnected.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_FullyConnected")
10 {
11 struct FullyConnectedFixture : public ParserFlatbuffersFixture
12 {
FullyConnectedFixtureFullyConnectedFixture13     explicit FullyConnectedFixture(const std::string& inputShape,
14                                    const std::string& outputShape,
15                                    const std::string& filterShape,
16                                    const std::string& filterData,
17                                    const std::string biasShape = "",
18                                    const std::string biasData = "",
19                                    const std::string dataType = "UINT8",
20                                    const std::string weightsDataType = "UINT8",
21                                    const std::string biasDataType = "INT32")
22     {
23         std::string inputTensors = "[ 0, 2 ]";
24         std::string biasTensor = "";
25         std::string biasBuffer = "";
26         if (biasShape.size() > 0 && biasData.size() > 0)
27         {
28             inputTensors = "[ 0, 2, 3 ]";
29             biasTensor = R"(
30                         {
31                             "shape": )" + biasShape + R"( ,
32                             "type": )" + biasDataType + R"(,
33                             "buffer": 3,
34                             "name": "biasTensor",
35                             "quantization": {
36                                 "min": [ 0.0 ],
37                                 "max": [ 255.0 ],
38                                 "scale": [ 1.0 ],
39                                 "zero_point": [ 0 ],
40                             }
41                         } )";
42             biasBuffer = R"(
43                     { "data": )" + biasData + R"(, }, )";
44         }
45         m_JsonString = R"(
46             {
47                 "version": 3,
48                 "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
49                 "subgraphs": [ {
50                     "tensors": [
51                         {
52                             "shape": )" + inputShape + R"(,
53                             "type": )" + dataType + R"(,
54                             "buffer": 0,
55                             "name": "inputTensor",
56                             "quantization": {
57                                 "min": [ 0.0 ],
58                                 "max": [ 255.0 ],
59                                 "scale": [ 1.0 ],
60                                 "zero_point": [ 0 ],
61                             }
62                         },
63                         {
64                             "shape": )" + outputShape + R"(,
65                             "type": )" + dataType + R"(,
66                             "buffer": 1,
67                             "name": "outputTensor",
68                             "quantization": {
69                                 "min": [ 0.0 ],
70                                 "max": [ 511.0 ],
71                                 "scale": [ 2.0 ],
72                                 "zero_point": [ 0 ],
73                             }
74                         },
75                         {
76                             "shape": )" + filterShape + R"(,
77                             "type": )" + weightsDataType + R"(,
78                             "buffer": 2,
79                             "name": "filterTensor",
80                             "quantization": {
81                                 "min": [ 0.0 ],
82                                 "max": [ 255.0 ],
83                                 "scale": [ 1.0 ],
84                                 "zero_point": [ 0 ],
85                             }
86                         }, )" + biasTensor + R"(
87                     ],
88                     "inputs": [ 0 ],
89                     "outputs": [ 1 ],
90                     "operators": [
91                         {
92                             "opcode_index": 0,
93                             "inputs": )" + inputTensors + R"(,
94                             "outputs": [ 1 ],
95                             "builtin_options_type": "FullyConnectedOptions",
96                             "builtin_options": {
97                                 "fused_activation_function": "NONE"
98                             },
99                             "custom_options_format": "FLEXBUFFERS"
100                         }
101                     ],
102                 } ],
103                 "buffers" : [
104                     { },
105                     { },
106                     { "data": )" + filterData + R"(, }, )"
107                        + biasBuffer + R"(
108                 ]
109             }
110         )";
111         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
112     }
113 };
114 
115 struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
116 {
FullyConnectedWithNoBiasFixtureFullyConnectedWithNoBiasFixture117     FullyConnectedWithNoBiasFixture()
118         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
119                                 "[ 1, 1 ]",           // outputShape
120                                 "[ 1, 4 ]",           // filterShape
121                                 "[ 2, 3, 4, 5 ]")     // filterData
122     {}
123 };
124 
125 TEST_CASE_FIXTURE(FullyConnectedWithNoBiasFixture, "FullyConnectedWithNoBias")
126 {
127     RunTest<2, armnn::DataType::QAsymmU8>(
128         0,
129         { 10, 20, 30, 40 },
130         { 400/2 });
131 }
132 
133 struct FullyConnectedWithBiasFixture : FullyConnectedFixture
134 {
FullyConnectedWithBiasFixtureFullyConnectedWithBiasFixture135     FullyConnectedWithBiasFixture()
136         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
137                                 "[ 1, 1 ]",           // outputShape
138                                 "[ 1, 4 ]",           // filterShape
139                                 "[ 2, 3, 4, 5 ]",     // filterData
140                                 "[ 1 ]",              // biasShape
141                                 "[ 10, 0, 0, 0 ]" )   // biasData
142     {}
143 };
144 
145 TEST_CASE_FIXTURE(FullyConnectedWithBiasFixture, "ParseFullyConnectedWithBias")
146 {
147     RunTest<2, armnn::DataType::QAsymmU8>(
148         0,
149         { 10, 20, 30, 40 },
150         { (400+10)/2 });
151 }
152 
153 struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
154 {
FullyConnectedWithBiasMultipleOutputsFixtureFullyConnectedWithBiasMultipleOutputsFixture155     FullyConnectedWithBiasMultipleOutputsFixture()
156             : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
157                                     "[ 2, 1 ]",           // outputShape
158                                     "[ 1, 4 ]",           // filterShape
159                                     "[ 2, 3, 4, 5 ]",     // filterData
160                                     "[ 1 ]",              // biasShape
161                                     "[ 10, 0, 0, 0 ]" )   // biasData
162     {}
163 };
164 
165 TEST_CASE_FIXTURE(FullyConnectedWithBiasMultipleOutputsFixture, "FullyConnectedWithBiasMultipleOutputs")
166 {
167     RunTest<2, armnn::DataType::QAsymmU8>(
168             0,
169             { 1, 2, 3, 4, 10, 20, 30, 40 },
170             { (40+10)/2, (400+10)/2 });
171 }
172 
173 struct DynamicFullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
174 {
DynamicFullyConnectedWithBiasMultipleOutputsFixtureDynamicFullyConnectedWithBiasMultipleOutputsFixture175     DynamicFullyConnectedWithBiasMultipleOutputsFixture()
176         : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
177                                 "[ ]",               // outputShape
178                                 "[ 1, 4 ]",           // filterShape
179                                 "[ 2, 3, 4, 5 ]",     // filterData
180                                 "[ 1 ]",              // biasShape
181                                 "[ 10, 0, 0, 0 ]" )   // biasData
182     { }
183 };
184 
185 TEST_CASE_FIXTURE(
186     DynamicFullyConnectedWithBiasMultipleOutputsFixture, "DynamicFullyConnectedWithBiasMultipleOutputs")
187 {
188     RunTest<2,
189             armnn::DataType::QAsymmU8,
190             armnn::DataType::QAsymmU8>(0,
191                                       { { "inputTensor", { 1, 2, 3, 4, 10, 20, 30, 40} } },
192                                       { { "outputTensor", { (40+10)/2, (400+10)/2 } } },
193                                       true);
194 }
195 
196 
197 struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
198 {
FullyConnectedNonConstWeightsFixtureFullyConnectedNonConstWeightsFixture199     explicit FullyConnectedNonConstWeightsFixture(const std::string& inputShape,
200                                                   const std::string& outputShape,
201                                                   const std::string& filterShape,
202                                                   const std::string biasShape = "")
203     {
204         std::string inputTensors = "[ 0, 1 ]";
205         std::string biasTensor = "";
206         std::string biasBuffer = "";
207         std::string outputs = "2";
208         if (biasShape.size() > 0)
209         {
210             inputTensors = "[ 0, 1, 2 ]";
211             biasTensor = R"(
212                        {
213                       "shape": )" + biasShape + R"(,
214                       "type": "INT32",
215                       "buffer": 2,
216                       "name": "bias",
217                       "quantization": {
218                         "scale": [ 1.0 ],
219                         "zero_point": [ 0 ],
220                         "details_type": 0,
221                         "quantized_dimension": 0
222                       },
223                       "is_variable": true
224                     }, )";
225 
226             biasBuffer = R"(,{ "data": [] } )";
227             outputs = "3";
228         }
229         m_JsonString = R"(
230             {
231               "version": 3,
232               "operator_codes": [
233                 {
234                   "builtin_code": "FULLY_CONNECTED",
235                   "version": 1
236                 }
237               ],
238               "subgraphs": [
239                 {
240                   "tensors": [
241                     {
242                       "shape": )" + inputShape + R"(,
243                       "type": "INT8",
244                       "buffer": 0,
245                       "name": "input_0",
246                       "quantization": {
247                         "scale": [ 1.0 ],
248                         "zero_point": [ 0 ],
249                         "details_type": 0,
250                         "quantized_dimension": 0
251                       },
252                     },
253                     {
254                       "shape": )" + filterShape + R"(,
255                       "type": "INT8",
256                       "buffer": 1,
257                       "name": "weights",
258                       "quantization": {
259                         "scale": [ 1.0 ],
260                         "zero_point": [ 0 ],
261                         "details_type": 0,
262                         "quantized_dimension": 0
263                       },
264                     },
265                     )" + biasTensor + R"(
266                     {
267                       "shape": )" + outputShape + R"(,
268                       "type": "INT8",
269                       "buffer": 0,
270                       "name": "output",
271                       "quantization": {
272                         "scale": [
273                           2.0
274                         ],
275                         "zero_point": [
276                           0
277                         ],
278                         "details_type": 0,
279                         "quantized_dimension": 0
280                       },
281                     }
282                   ],
283                   "inputs": )" + inputTensors + R"(,
284                   "outputs": [ )" + outputs + R"( ],
285                   "operators": [
286                     {
287                       "opcode_index": 0,
288                       "inputs": )" + inputTensors + R"(,
289                       "outputs": [ )" + outputs + R"( ],
290                       "builtin_options_type": "FullyConnectedOptions",
291                       "builtin_options": {
292                         "fused_activation_function": "NONE",
293                         "weights_format": "DEFAULT",
294                         "keep_num_dims": false,
295                         "asymmetric_quantize_inputs": false
296                       },
297                       "custom_options_format": "FLEXBUFFERS"
298                     }
299                   ]
300                 }
301               ],
302               "description": "ArmnnDelegate: FullyConnected Operator Model",
303               "buffers": [
304                 {
305                   "data": []
306                 },
307                 {
308                   "data": []
309                 }
310                 )" + biasBuffer + R"(
311               ]
312             }
313             )";
314         Setup();
315     }
316 };
317 
318 struct FullyConnectedNonConstWeights : FullyConnectedNonConstWeightsFixture
319 {
FullyConnectedNonConstWeightsFullyConnectedNonConstWeights320     FullyConnectedNonConstWeights()
321             : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]",     // inputShape
322                                                    "[ 1, 1 ]",           // outputShape
323                                                    "[ 1, 4 ]",           // filterShape
324                                                    "[ 1 ]" )             // biasShape
325 
326     {}
327 };
328 
329 TEST_CASE_FIXTURE(FullyConnectedNonConstWeights, "ParseFullyConnectedNonConstWeights")
330 {
331     RunTest<2, armnn::DataType::QAsymmS8,
332             armnn::DataType::Signed32,
333             armnn::DataType::QAsymmS8>(
334             0,
335             {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
336             {{"bias", { 10 }}},
337             {{"output", { 25 }}});
338 }
339 
340 struct FullyConnectedNonConstWeightsNoBias : FullyConnectedNonConstWeightsFixture
341 {
FullyConnectedNonConstWeightsNoBiasFullyConnectedNonConstWeightsNoBias342     FullyConnectedNonConstWeightsNoBias()
343             : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]",     // inputShape
344                                                    "[ 1, 1 ]",           // outputShape
345                                                    "[ 1, 4 ]")           // filterShape
346 
347     {}
348 };
349 
350 TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonConstWeightsNoBias")
351 {
352     RunTest<2, armnn::DataType::QAsymmS8,
353             armnn::DataType::QAsymmS8>(
354             0,
355             {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
356             {{"output", { 20 }}});
357 }
358 
359 struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
360 {
FullyConnectedWeightsBiasFloatFullyConnectedWeightsBiasFloat361     FullyConnectedWeightsBiasFloat()
362             : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
363                                     "[ 1, 1, 1, 1 ]",     // outputShape
364                                     "[ 1, 4 ]",           // filterShape
365                                     "[ 2, 3, 4, 5 ]",     // filterData
366                                     "[ 1 ]",              // biasShape
367                                     "[ 10, 0, 0, 0 ]",    // filterShape
368                                     "FLOAT32",            // input and output dataType
369                                     "INT8",               // weights dataType
370                                     "FLOAT32")            // bias dataType
371     {}
372 };
373 
374 TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
375 {
376     RunTest<4, armnn::DataType::Float32>(
377             0,
378             { 10, 20, 30, 40 },
379             { 400 });
380 }
381 
382 }
383