xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/FullyConnected.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_FullyConnected")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedFixtureFullyConnectedFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit FullyConnectedFixture(const std::string& inputShape,
14*89c4ff92SAndroid Build Coastguard Worker                                    const std::string& outputShape,
15*89c4ff92SAndroid Build Coastguard Worker                                    const std::string& filterShape,
16*89c4ff92SAndroid Build Coastguard Worker                                    const std::string& filterData,
17*89c4ff92SAndroid Build Coastguard Worker                                    const std::string biasShape = "",
18*89c4ff92SAndroid Build Coastguard Worker                                    const std::string biasData = "",
19*89c4ff92SAndroid Build Coastguard Worker                                    const std::string dataType = "UINT8",
20*89c4ff92SAndroid Build Coastguard Worker                                    const std::string weightsDataType = "UINT8",
21*89c4ff92SAndroid Build Coastguard Worker                                    const std::string biasDataType = "INT32")
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         std::string inputTensors = "[ 0, 2 ]";
24*89c4ff92SAndroid Build Coastguard Worker         std::string biasTensor = "";
25*89c4ff92SAndroid Build Coastguard Worker         std::string biasBuffer = "";
26*89c4ff92SAndroid Build Coastguard Worker         if (biasShape.size() > 0 && biasData.size() > 0)
27*89c4ff92SAndroid Build Coastguard Worker         {
28*89c4ff92SAndroid Build Coastguard Worker             inputTensors = "[ 0, 2, 3 ]";
29*89c4ff92SAndroid Build Coastguard Worker             biasTensor = R"(
30*89c4ff92SAndroid Build Coastguard Worker                         {
31*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + biasShape + R"( ,
32*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + biasDataType + R"(,
33*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 3,
34*89c4ff92SAndroid Build Coastguard Worker                             "name": "biasTensor",
35*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
36*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
37*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
38*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
39*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
40*89c4ff92SAndroid Build Coastguard Worker                             }
41*89c4ff92SAndroid Build Coastguard Worker                         } )";
42*89c4ff92SAndroid Build Coastguard Worker             biasBuffer = R"(
43*89c4ff92SAndroid Build Coastguard Worker                     { "data": )" + biasData + R"(, }, )";
44*89c4ff92SAndroid Build Coastguard Worker         }
45*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
46*89c4ff92SAndroid Build Coastguard Worker             {
47*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
48*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
49*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
50*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
51*89c4ff92SAndroid Build Coastguard Worker                         {
52*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + inputShape + R"(,
53*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + dataType + R"(,
54*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
55*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
56*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
57*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
58*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
59*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
60*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
61*89c4ff92SAndroid Build Coastguard Worker                             }
62*89c4ff92SAndroid Build Coastguard Worker                         },
63*89c4ff92SAndroid Build Coastguard Worker                         {
64*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + outputShape + R"(,
65*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + dataType + R"(,
66*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
67*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
68*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
69*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
70*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 511.0 ],
71*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 2.0 ],
72*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
73*89c4ff92SAndroid Build Coastguard Worker                             }
74*89c4ff92SAndroid Build Coastguard Worker                         },
75*89c4ff92SAndroid Build Coastguard Worker                         {
76*89c4ff92SAndroid Build Coastguard Worker                             "shape": )" + filterShape + R"(,
77*89c4ff92SAndroid Build Coastguard Worker                             "type": )" + weightsDataType + R"(,
78*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
79*89c4ff92SAndroid Build Coastguard Worker                             "name": "filterTensor",
80*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
81*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
82*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
83*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
84*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
85*89c4ff92SAndroid Build Coastguard Worker                             }
86*89c4ff92SAndroid Build Coastguard Worker                         }, )" + biasTensor + R"(
87*89c4ff92SAndroid Build Coastguard Worker                     ],
88*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
89*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
90*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
91*89c4ff92SAndroid Build Coastguard Worker                         {
92*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
93*89c4ff92SAndroid Build Coastguard Worker                             "inputs": )" + inputTensors + R"(,
94*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
95*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "FullyConnectedOptions",
96*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {
97*89c4ff92SAndroid Build Coastguard Worker                                 "fused_activation_function": "NONE"
98*89c4ff92SAndroid Build Coastguard Worker                             },
99*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
100*89c4ff92SAndroid Build Coastguard Worker                         }
101*89c4ff92SAndroid Build Coastguard Worker                     ],
102*89c4ff92SAndroid Build Coastguard Worker                 } ],
103*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [
104*89c4ff92SAndroid Build Coastguard Worker                     { },
105*89c4ff92SAndroid Build Coastguard Worker                     { },
106*89c4ff92SAndroid Build Coastguard Worker                     { "data": )" + filterData + R"(, }, )"
107*89c4ff92SAndroid Build Coastguard Worker                        + biasBuffer + R"(
108*89c4ff92SAndroid Build Coastguard Worker                 ]
109*89c4ff92SAndroid Build Coastguard Worker             }
110*89c4ff92SAndroid Build Coastguard Worker         )";
111*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
112*89c4ff92SAndroid Build Coastguard Worker     }
113*89c4ff92SAndroid Build Coastguard Worker };
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
116*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedWithNoBiasFixtureFullyConnectedWithNoBiasFixture117*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedWithNoBiasFixture()
118*89c4ff92SAndroid Build Coastguard Worker         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
119*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1, 1 ]",           // outputShape
120*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1, 4 ]",           // filterShape
121*89c4ff92SAndroid Build Coastguard Worker                                 "[ 2, 3, 4, 5 ]")     // filterData
122*89c4ff92SAndroid Build Coastguard Worker     {}
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedWithNoBiasFixture, "FullyConnectedWithNoBias")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
128*89c4ff92SAndroid Build Coastguard Worker         0,
129*89c4ff92SAndroid Build Coastguard Worker         { 10, 20, 30, 40 },
130*89c4ff92SAndroid Build Coastguard Worker         { 400/2 });
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedWithBiasFixture : FullyConnectedFixture
134*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedWithBiasFixtureFullyConnectedWithBiasFixture135*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedWithBiasFixture()
136*89c4ff92SAndroid Build Coastguard Worker         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
137*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1, 1 ]",           // outputShape
138*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1, 4 ]",           // filterShape
139*89c4ff92SAndroid Build Coastguard Worker                                 "[ 2, 3, 4, 5 ]",     // filterData
140*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1 ]",              // biasShape
141*89c4ff92SAndroid Build Coastguard Worker                                 "[ 10, 0, 0, 0 ]" )   // biasData
142*89c4ff92SAndroid Build Coastguard Worker     {}
143*89c4ff92SAndroid Build Coastguard Worker };
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedWithBiasFixture, "ParseFullyConnectedWithBias")
146*89c4ff92SAndroid Build Coastguard Worker {
147*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
148*89c4ff92SAndroid Build Coastguard Worker         0,
149*89c4ff92SAndroid Build Coastguard Worker         { 10, 20, 30, 40 },
150*89c4ff92SAndroid Build Coastguard Worker         { (400+10)/2 });
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
154*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedWithBiasMultipleOutputsFixtureFullyConnectedWithBiasMultipleOutputsFixture155*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedWithBiasMultipleOutputsFixture()
156*89c4ff92SAndroid Build Coastguard Worker             : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
157*89c4ff92SAndroid Build Coastguard Worker                                     "[ 2, 1 ]",           // outputShape
158*89c4ff92SAndroid Build Coastguard Worker                                     "[ 1, 4 ]",           // filterShape
159*89c4ff92SAndroid Build Coastguard Worker                                     "[ 2, 3, 4, 5 ]",     // filterData
160*89c4ff92SAndroid Build Coastguard Worker                                     "[ 1 ]",              // biasShape
161*89c4ff92SAndroid Build Coastguard Worker                                     "[ 10, 0, 0, 0 ]" )   // biasData
162*89c4ff92SAndroid Build Coastguard Worker     {}
163*89c4ff92SAndroid Build Coastguard Worker };
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedWithBiasMultipleOutputsFixture, "FullyConnectedWithBiasMultipleOutputs")
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(
168*89c4ff92SAndroid Build Coastguard Worker             0,
169*89c4ff92SAndroid Build Coastguard Worker             { 1, 2, 3, 4, 10, 20, 30, 40 },
170*89c4ff92SAndroid Build Coastguard Worker             { (40+10)/2, (400+10)/2 });
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker struct DynamicFullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
174*89c4ff92SAndroid Build Coastguard Worker {
DynamicFullyConnectedWithBiasMultipleOutputsFixtureDynamicFullyConnectedWithBiasMultipleOutputsFixture175*89c4ff92SAndroid Build Coastguard Worker     DynamicFullyConnectedWithBiasMultipleOutputsFixture()
176*89c4ff92SAndroid Build Coastguard Worker         : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
177*89c4ff92SAndroid Build Coastguard Worker                                 "[ ]",               // outputShape
178*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1, 4 ]",           // filterShape
179*89c4ff92SAndroid Build Coastguard Worker                                 "[ 2, 3, 4, 5 ]",     // filterData
180*89c4ff92SAndroid Build Coastguard Worker                                 "[ 1 ]",              // biasShape
181*89c4ff92SAndroid Build Coastguard Worker                                 "[ 10, 0, 0, 0 ]" )   // biasData
182*89c4ff92SAndroid Build Coastguard Worker     { }
183*89c4ff92SAndroid Build Coastguard Worker };
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(
186*89c4ff92SAndroid Build Coastguard Worker     DynamicFullyConnectedWithBiasMultipleOutputsFixture, "DynamicFullyConnectedWithBiasMultipleOutputs")
187*89c4ff92SAndroid Build Coastguard Worker {
188*89c4ff92SAndroid Build Coastguard Worker     RunTest<2,
189*89c4ff92SAndroid Build Coastguard Worker             armnn::DataType::QAsymmU8,
190*89c4ff92SAndroid Build Coastguard Worker             armnn::DataType::QAsymmU8>(0,
191*89c4ff92SAndroid Build Coastguard Worker                                       { { "inputTensor", { 1, 2, 3, 4, 10, 20, 30, 40} } },
192*89c4ff92SAndroid Build Coastguard Worker                                       { { "outputTensor", { (40+10)/2, (400+10)/2 } } },
193*89c4ff92SAndroid Build Coastguard Worker                                       true);
194*89c4ff92SAndroid Build Coastguard Worker }
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedNonConstWeightsFixture : public ParserFlatbuffersFixture
198*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedNonConstWeightsFixtureFullyConnectedNonConstWeightsFixture199*89c4ff92SAndroid Build Coastguard Worker     explicit FullyConnectedNonConstWeightsFixture(const std::string& inputShape,
200*89c4ff92SAndroid Build Coastguard Worker                                                   const std::string& outputShape,
201*89c4ff92SAndroid Build Coastguard Worker                                                   const std::string& filterShape,
202*89c4ff92SAndroid Build Coastguard Worker                                                   const std::string biasShape = "")
203*89c4ff92SAndroid Build Coastguard Worker     {
204*89c4ff92SAndroid Build Coastguard Worker         std::string inputTensors = "[ 0, 1 ]";
205*89c4ff92SAndroid Build Coastguard Worker         std::string biasTensor = "";
206*89c4ff92SAndroid Build Coastguard Worker         std::string biasBuffer = "";
207*89c4ff92SAndroid Build Coastguard Worker         std::string outputs = "2";
208*89c4ff92SAndroid Build Coastguard Worker         if (biasShape.size() > 0)
209*89c4ff92SAndroid Build Coastguard Worker         {
210*89c4ff92SAndroid Build Coastguard Worker             inputTensors = "[ 0, 1, 2 ]";
211*89c4ff92SAndroid Build Coastguard Worker             biasTensor = R"(
212*89c4ff92SAndroid Build Coastguard Worker                        {
213*89c4ff92SAndroid Build Coastguard Worker                       "shape": )" + biasShape + R"(,
214*89c4ff92SAndroid Build Coastguard Worker                       "type": "INT32",
215*89c4ff92SAndroid Build Coastguard Worker                       "buffer": 2,
216*89c4ff92SAndroid Build Coastguard Worker                       "name": "bias",
217*89c4ff92SAndroid Build Coastguard Worker                       "quantization": {
218*89c4ff92SAndroid Build Coastguard Worker                         "scale": [ 1.0 ],
219*89c4ff92SAndroid Build Coastguard Worker                         "zero_point": [ 0 ],
220*89c4ff92SAndroid Build Coastguard Worker                         "details_type": 0,
221*89c4ff92SAndroid Build Coastguard Worker                         "quantized_dimension": 0
222*89c4ff92SAndroid Build Coastguard Worker                       },
223*89c4ff92SAndroid Build Coastguard Worker                       "is_variable": true
224*89c4ff92SAndroid Build Coastguard Worker                     }, )";
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker             biasBuffer = R"(,{ "data": [] } )";
227*89c4ff92SAndroid Build Coastguard Worker             outputs = "3";
228*89c4ff92SAndroid Build Coastguard Worker         }
229*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
230*89c4ff92SAndroid Build Coastguard Worker             {
231*89c4ff92SAndroid Build Coastguard Worker               "version": 3,
232*89c4ff92SAndroid Build Coastguard Worker               "operator_codes": [
233*89c4ff92SAndroid Build Coastguard Worker                 {
234*89c4ff92SAndroid Build Coastguard Worker                   "builtin_code": "FULLY_CONNECTED",
235*89c4ff92SAndroid Build Coastguard Worker                   "version": 1
236*89c4ff92SAndroid Build Coastguard Worker                 }
237*89c4ff92SAndroid Build Coastguard Worker               ],
238*89c4ff92SAndroid Build Coastguard Worker               "subgraphs": [
239*89c4ff92SAndroid Build Coastguard Worker                 {
240*89c4ff92SAndroid Build Coastguard Worker                   "tensors": [
241*89c4ff92SAndroid Build Coastguard Worker                     {
242*89c4ff92SAndroid Build Coastguard Worker                       "shape": )" + inputShape + R"(,
243*89c4ff92SAndroid Build Coastguard Worker                       "type": "INT8",
244*89c4ff92SAndroid Build Coastguard Worker                       "buffer": 0,
245*89c4ff92SAndroid Build Coastguard Worker                       "name": "input_0",
246*89c4ff92SAndroid Build Coastguard Worker                       "quantization": {
247*89c4ff92SAndroid Build Coastguard Worker                         "scale": [ 1.0 ],
248*89c4ff92SAndroid Build Coastguard Worker                         "zero_point": [ 0 ],
249*89c4ff92SAndroid Build Coastguard Worker                         "details_type": 0,
250*89c4ff92SAndroid Build Coastguard Worker                         "quantized_dimension": 0
251*89c4ff92SAndroid Build Coastguard Worker                       },
252*89c4ff92SAndroid Build Coastguard Worker                     },
253*89c4ff92SAndroid Build Coastguard Worker                     {
254*89c4ff92SAndroid Build Coastguard Worker                       "shape": )" + filterShape + R"(,
255*89c4ff92SAndroid Build Coastguard Worker                       "type": "INT8",
256*89c4ff92SAndroid Build Coastguard Worker                       "buffer": 1,
257*89c4ff92SAndroid Build Coastguard Worker                       "name": "weights",
258*89c4ff92SAndroid Build Coastguard Worker                       "quantization": {
259*89c4ff92SAndroid Build Coastguard Worker                         "scale": [ 1.0 ],
260*89c4ff92SAndroid Build Coastguard Worker                         "zero_point": [ 0 ],
261*89c4ff92SAndroid Build Coastguard Worker                         "details_type": 0,
262*89c4ff92SAndroid Build Coastguard Worker                         "quantized_dimension": 0
263*89c4ff92SAndroid Build Coastguard Worker                       },
264*89c4ff92SAndroid Build Coastguard Worker                     },
265*89c4ff92SAndroid Build Coastguard Worker                     )" + biasTensor + R"(
266*89c4ff92SAndroid Build Coastguard Worker                     {
267*89c4ff92SAndroid Build Coastguard Worker                       "shape": )" + outputShape + R"(,
268*89c4ff92SAndroid Build Coastguard Worker                       "type": "INT8",
269*89c4ff92SAndroid Build Coastguard Worker                       "buffer": 0,
270*89c4ff92SAndroid Build Coastguard Worker                       "name": "output",
271*89c4ff92SAndroid Build Coastguard Worker                       "quantization": {
272*89c4ff92SAndroid Build Coastguard Worker                         "scale": [
273*89c4ff92SAndroid Build Coastguard Worker                           2.0
274*89c4ff92SAndroid Build Coastguard Worker                         ],
275*89c4ff92SAndroid Build Coastguard Worker                         "zero_point": [
276*89c4ff92SAndroid Build Coastguard Worker                           0
277*89c4ff92SAndroid Build Coastguard Worker                         ],
278*89c4ff92SAndroid Build Coastguard Worker                         "details_type": 0,
279*89c4ff92SAndroid Build Coastguard Worker                         "quantized_dimension": 0
280*89c4ff92SAndroid Build Coastguard Worker                       },
281*89c4ff92SAndroid Build Coastguard Worker                     }
282*89c4ff92SAndroid Build Coastguard Worker                   ],
283*89c4ff92SAndroid Build Coastguard Worker                   "inputs": )" + inputTensors + R"(,
284*89c4ff92SAndroid Build Coastguard Worker                   "outputs": [ )" + outputs + R"( ],
285*89c4ff92SAndroid Build Coastguard Worker                   "operators": [
286*89c4ff92SAndroid Build Coastguard Worker                     {
287*89c4ff92SAndroid Build Coastguard Worker                       "opcode_index": 0,
288*89c4ff92SAndroid Build Coastguard Worker                       "inputs": )" + inputTensors + R"(,
289*89c4ff92SAndroid Build Coastguard Worker                       "outputs": [ )" + outputs + R"( ],
290*89c4ff92SAndroid Build Coastguard Worker                       "builtin_options_type": "FullyConnectedOptions",
291*89c4ff92SAndroid Build Coastguard Worker                       "builtin_options": {
292*89c4ff92SAndroid Build Coastguard Worker                         "fused_activation_function": "NONE",
293*89c4ff92SAndroid Build Coastguard Worker                         "weights_format": "DEFAULT",
294*89c4ff92SAndroid Build Coastguard Worker                         "keep_num_dims": false,
295*89c4ff92SAndroid Build Coastguard Worker                         "asymmetric_quantize_inputs": false
296*89c4ff92SAndroid Build Coastguard Worker                       },
297*89c4ff92SAndroid Build Coastguard Worker                       "custom_options_format": "FLEXBUFFERS"
298*89c4ff92SAndroid Build Coastguard Worker                     }
299*89c4ff92SAndroid Build Coastguard Worker                   ]
300*89c4ff92SAndroid Build Coastguard Worker                 }
301*89c4ff92SAndroid Build Coastguard Worker               ],
302*89c4ff92SAndroid Build Coastguard Worker               "description": "ArmnnDelegate: FullyConnected Operator Model",
303*89c4ff92SAndroid Build Coastguard Worker               "buffers": [
304*89c4ff92SAndroid Build Coastguard Worker                 {
305*89c4ff92SAndroid Build Coastguard Worker                   "data": []
306*89c4ff92SAndroid Build Coastguard Worker                 },
307*89c4ff92SAndroid Build Coastguard Worker                 {
308*89c4ff92SAndroid Build Coastguard Worker                   "data": []
309*89c4ff92SAndroid Build Coastguard Worker                 }
310*89c4ff92SAndroid Build Coastguard Worker                 )" + biasBuffer + R"(
311*89c4ff92SAndroid Build Coastguard Worker               ]
312*89c4ff92SAndroid Build Coastguard Worker             }
313*89c4ff92SAndroid Build Coastguard Worker             )";
314*89c4ff92SAndroid Build Coastguard Worker         Setup();
315*89c4ff92SAndroid Build Coastguard Worker     }
316*89c4ff92SAndroid Build Coastguard Worker };
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedNonConstWeights : FullyConnectedNonConstWeightsFixture
319*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedNonConstWeightsFullyConnectedNonConstWeights320*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedNonConstWeights()
321*89c4ff92SAndroid Build Coastguard Worker             : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]",     // inputShape
322*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1 ]",           // outputShape
323*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 4 ]",           // filterShape
324*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1 ]" )             // biasShape
325*89c4ff92SAndroid Build Coastguard Worker 
326*89c4ff92SAndroid Build Coastguard Worker     {}
327*89c4ff92SAndroid Build Coastguard Worker };
328*89c4ff92SAndroid Build Coastguard Worker 
329*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedNonConstWeights, "ParseFullyConnectedNonConstWeights")
330*89c4ff92SAndroid Build Coastguard Worker {
331*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmS8,
332*89c4ff92SAndroid Build Coastguard Worker             armnn::DataType::Signed32,
333*89c4ff92SAndroid Build Coastguard Worker             armnn::DataType::QAsymmS8>(
334*89c4ff92SAndroid Build Coastguard Worker             0,
335*89c4ff92SAndroid Build Coastguard Worker             {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
336*89c4ff92SAndroid Build Coastguard Worker             {{"bias", { 10 }}},
337*89c4ff92SAndroid Build Coastguard Worker             {{"output", { 25 }}});
338*89c4ff92SAndroid Build Coastguard Worker }
339*89c4ff92SAndroid Build Coastguard Worker 
340*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedNonConstWeightsNoBias : FullyConnectedNonConstWeightsFixture
341*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedNonConstWeightsNoBiasFullyConnectedNonConstWeightsNoBias342*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedNonConstWeightsNoBias()
343*89c4ff92SAndroid Build Coastguard Worker             : FullyConnectedNonConstWeightsFixture("[ 1, 4, 1, 1 ]",     // inputShape
344*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1 ]",           // outputShape
345*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 4 ]")           // filterShape
346*89c4ff92SAndroid Build Coastguard Worker 
347*89c4ff92SAndroid Build Coastguard Worker     {}
348*89c4ff92SAndroid Build Coastguard Worker };
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonConstWeightsNoBias")
351*89c4ff92SAndroid Build Coastguard Worker {
352*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmS8,
353*89c4ff92SAndroid Build Coastguard Worker             armnn::DataType::QAsymmS8>(
354*89c4ff92SAndroid Build Coastguard Worker             0,
355*89c4ff92SAndroid Build Coastguard Worker             {{{"input_0", { 1, 2, 3, 4 }},{"weights", { 2, 3, 4, 5 }}}},
356*89c4ff92SAndroid Build Coastguard Worker             {{"output", { 20 }}});
357*89c4ff92SAndroid Build Coastguard Worker }
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
360*89c4ff92SAndroid Build Coastguard Worker {
FullyConnectedWeightsBiasFloatFullyConnectedWeightsBiasFloat361*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedWeightsBiasFloat()
362*89c4ff92SAndroid Build Coastguard Worker             : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
363*89c4ff92SAndroid Build Coastguard Worker                                     "[ 1, 1, 1, 1 ]",     // outputShape
364*89c4ff92SAndroid Build Coastguard Worker                                     "[ 1, 4 ]",           // filterShape
365*89c4ff92SAndroid Build Coastguard Worker                                     "[ 2, 3, 4, 5 ]",     // filterData
366*89c4ff92SAndroid Build Coastguard Worker                                     "[ 1 ]",              // biasShape
367*89c4ff92SAndroid Build Coastguard Worker                                     "[ 10, 0, 0, 0 ]",    // filterShape
368*89c4ff92SAndroid Build Coastguard Worker                                     "FLOAT32",            // input and output dataType
369*89c4ff92SAndroid Build Coastguard Worker                                     "INT8",               // weights dataType
370*89c4ff92SAndroid Build Coastguard Worker                                     "FLOAT32")            // bias dataType
371*89c4ff92SAndroid Build Coastguard Worker     {}
372*89c4ff92SAndroid Build Coastguard Worker };
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
375*89c4ff92SAndroid Build Coastguard Worker {
376*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(
377*89c4ff92SAndroid Build Coastguard Worker             0,
378*89c4ff92SAndroid Build Coastguard Worker             { 10, 20, 30, 40 },
379*89c4ff92SAndroid Build Coastguard Worker             { 400 });
380*89c4ff92SAndroid Build Coastguard Worker }
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker }
383