xref: /aosp_15_r20/external/android-nn-driver/test/FullyConnected.cpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #include "DriverTestHelpers.hpp"
7*3e777be0SXin Li 
8*3e777be0SXin Li #include <log/log.h>
9*3e777be0SXin Li 
10*3e777be0SXin Li DOCTEST_TEST_SUITE("FullyConnectedTests")
11*3e777be0SXin Li {
12*3e777be0SXin Li using namespace android::hardware;
13*3e777be0SXin Li using namespace driverTestHelpers;
14*3e777be0SXin Li using namespace armnn_driver;
15*3e777be0SXin Li 
16*3e777be0SXin Li using HalPolicy = hal_1_0::HalPolicy;
17*3e777be0SXin Li 
18*3e777be0SXin Li // Add our own test here since we fail the fc tests which Google supplies (because of non-const weights)
19*3e777be0SXin Li DOCTEST_TEST_CASE("FullyConnected")
20*3e777be0SXin Li {
21*3e777be0SXin Li     // this should ideally replicate fully_connected_float.model.cpp
22*3e777be0SXin Li     // but that uses slightly weird dimensions which I don't think we need to support for now
23*3e777be0SXin Li 
24*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
25*3e777be0SXin Li     HalPolicy::Model model = {};
26*3e777be0SXin Li 
27*3e777be0SXin Li     // add operands
28*3e777be0SXin Li     int32_t actValue      = 0;
29*3e777be0SXin Li     float   weightValue[] = {2, 4, 1};
30*3e777be0SXin Li     float   biasValue[]   = {4};
31*3e777be0SXin Li 
32*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3});
33*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3}, weightValue);
34*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1}, biasValue);
35*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, actValue);
36*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1});
37*3e777be0SXin Li 
38*3e777be0SXin Li     // make the fully connected operation
39*3e777be0SXin Li     model.operations.resize(1);
40*3e777be0SXin Li     model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
41*3e777be0SXin Li     model.operations[0].inputs  = hidl_vec<uint32_t>{0, 1, 2, 3};
42*3e777be0SXin Li     model.operations[0].outputs = hidl_vec<uint32_t>{4};
43*3e777be0SXin Li 
44*3e777be0SXin Li     // make the prepared model
45*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
46*3e777be0SXin Li 
47*3e777be0SXin Li     // construct the request
48*3e777be0SXin Li     V1_0::DataLocation inloc = {};
49*3e777be0SXin Li     inloc.poolIndex = 0;
50*3e777be0SXin Li     inloc.offset    = 0;
51*3e777be0SXin Li     inloc.length    = 3 * sizeof(float);
52*3e777be0SXin Li     RequestArgument input = {};
53*3e777be0SXin Li     input.location = inloc;
54*3e777be0SXin Li     input.dimensions = hidl_vec<uint32_t>{};
55*3e777be0SXin Li 
56*3e777be0SXin Li     V1_0::DataLocation outloc = {};
57*3e777be0SXin Li     outloc.poolIndex = 1;
58*3e777be0SXin Li     outloc.offset    = 0;
59*3e777be0SXin Li     outloc.length    = 1 * sizeof(float);
60*3e777be0SXin Li     RequestArgument output = {};
61*3e777be0SXin Li     output.location  = outloc;
62*3e777be0SXin Li     output.dimensions = hidl_vec<uint32_t>{};
63*3e777be0SXin Li 
64*3e777be0SXin Li     V1_0::Request request = {};
65*3e777be0SXin Li     request.inputs  = hidl_vec<RequestArgument>{input};
66*3e777be0SXin Li     request.outputs = hidl_vec<RequestArgument>{output};
67*3e777be0SXin Li 
68*3e777be0SXin Li     // set the input data (matching source test)
69*3e777be0SXin Li     float indata[] = {2, 32, 16};
70*3e777be0SXin Li     AddPoolAndSetData<float>(3, request, indata);
71*3e777be0SXin Li 
72*3e777be0SXin Li     // add memory for the output
73*3e777be0SXin Li     android::sp<IMemory> outMemory = AddPoolAndGetData<float>(1, request);
74*3e777be0SXin Li     float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
75*3e777be0SXin Li 
76*3e777be0SXin Li     // run the execution
77*3e777be0SXin Li     if (preparedModel.get() != nullptr)
78*3e777be0SXin Li     {
79*3e777be0SXin Li         Execute(preparedModel, request);
80*3e777be0SXin Li     }
81*3e777be0SXin Li 
82*3e777be0SXin Li     // check the result
83*3e777be0SXin Li     DOCTEST_CHECK(outdata[0] == 152);
84*3e777be0SXin Li }
85*3e777be0SXin Li 
86*3e777be0SXin Li DOCTEST_TEST_CASE("TestFullyConnected4dInput")
87*3e777be0SXin Li {
88*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
89*3e777be0SXin Li 
90*3e777be0SXin Li     V1_0::ErrorStatus error;
91*3e777be0SXin Li     std::vector<bool> sup;
92*3e777be0SXin Li 
93*3e777be0SXin Li     ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
__anon493742780102(V1_0::ErrorStatus status, const std::vector<bool>& supported) 94*3e777be0SXin Li         {
95*3e777be0SXin Li             error = status;
96*3e777be0SXin Li             sup = supported;
97*3e777be0SXin Li         };
98*3e777be0SXin Li 
99*3e777be0SXin Li     HalPolicy::Model model = {};
100*3e777be0SXin Li 
101*3e777be0SXin Li     // operands
102*3e777be0SXin Li     int32_t actValue      = 0;
103*3e777be0SXin Li     float   weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
104*3e777be0SXin Li                              0, 1, 0, 0, 0, 0, 0, 0,
105*3e777be0SXin Li                              0, 0, 1, 0, 0, 0, 0, 0,
106*3e777be0SXin Li                              0, 0, 0, 1, 0, 0, 0, 0,
107*3e777be0SXin Li                              0, 0, 0, 0, 1, 0, 0, 0,
108*3e777be0SXin Li                              0, 0, 0, 0, 0, 1, 0, 0,
109*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 1, 0,
110*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 0, 1}; //identity
111*3e777be0SXin Li     float   biasValue[]   = {0, 0, 0, 0, 0, 0, 0, 0};
112*3e777be0SXin Li 
113*3e777be0SXin Li     // fully connected operation
114*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 8});
115*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8}, weightValue);
116*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8}, biasValue);
117*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, actValue);
118*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
119*3e777be0SXin Li 
120*3e777be0SXin Li     model.operations.resize(1);
121*3e777be0SXin Li 
122*3e777be0SXin Li     model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
123*3e777be0SXin Li     model.operations[0].inputs  = hidl_vec<uint32_t>{0,1,2,3};
124*3e777be0SXin Li     model.operations[0].outputs = hidl_vec<uint32_t>{4};
125*3e777be0SXin Li 
126*3e777be0SXin Li     // make the prepared model
127*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
128*3e777be0SXin Li 
129*3e777be0SXin Li     // construct the request
130*3e777be0SXin Li     V1_0::DataLocation inloc = {};
131*3e777be0SXin Li     inloc.poolIndex          = 0;
132*3e777be0SXin Li     inloc.offset             = 0;
133*3e777be0SXin Li     inloc.length             = 8 * sizeof(float);
134*3e777be0SXin Li     RequestArgument input    = {};
135*3e777be0SXin Li     input.location           = inloc;
136*3e777be0SXin Li     input.dimensions         = hidl_vec<uint32_t>{};
137*3e777be0SXin Li 
138*3e777be0SXin Li     V1_0::DataLocation outloc = {};
139*3e777be0SXin Li     outloc.poolIndex          = 1;
140*3e777be0SXin Li     outloc.offset             = 0;
141*3e777be0SXin Li     outloc.length             = 8 * sizeof(float);
142*3e777be0SXin Li     RequestArgument output    = {};
143*3e777be0SXin Li     output.location           = outloc;
144*3e777be0SXin Li     output.dimensions         = hidl_vec<uint32_t>{};
145*3e777be0SXin Li 
146*3e777be0SXin Li     V1_0::Request request = {};
147*3e777be0SXin Li     request.inputs  = hidl_vec<RequestArgument>{input};
148*3e777be0SXin Li     request.outputs = hidl_vec<RequestArgument>{output};
149*3e777be0SXin Li 
150*3e777be0SXin Li     // set the input data
151*3e777be0SXin Li     float indata[] = {1,2,3,4,5,6,7,8};
152*3e777be0SXin Li     AddPoolAndSetData(8, request, indata);
153*3e777be0SXin Li 
154*3e777be0SXin Li     // add memory for the output
155*3e777be0SXin Li     android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
156*3e777be0SXin Li     float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
157*3e777be0SXin Li 
158*3e777be0SXin Li     // run the execution
159*3e777be0SXin Li     if (preparedModel != nullptr)
160*3e777be0SXin Li     {
161*3e777be0SXin Li         Execute(preparedModel, request);
162*3e777be0SXin Li     }
163*3e777be0SXin Li 
164*3e777be0SXin Li     // check the result
165*3e777be0SXin Li     DOCTEST_CHECK(outdata[0] == 1);
166*3e777be0SXin Li     DOCTEST_CHECK(outdata[1] == 2);
167*3e777be0SXin Li     DOCTEST_CHECK(outdata[2] == 3);
168*3e777be0SXin Li     DOCTEST_CHECK(outdata[3] == 4);
169*3e777be0SXin Li     DOCTEST_CHECK(outdata[4] == 5);
170*3e777be0SXin Li     DOCTEST_CHECK(outdata[5] == 6);
171*3e777be0SXin Li     DOCTEST_CHECK(outdata[6] == 7);
172*3e777be0SXin Li     DOCTEST_CHECK(outdata[7] == 8);
173*3e777be0SXin Li }
174*3e777be0SXin Li 
175*3e777be0SXin Li DOCTEST_TEST_CASE("TestFullyConnected4dInputReshape")
176*3e777be0SXin Li {
177*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
178*3e777be0SXin Li 
179*3e777be0SXin Li     V1_0::ErrorStatus error;
180*3e777be0SXin Li     std::vector<bool> sup;
181*3e777be0SXin Li 
182*3e777be0SXin Li     ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
__anon493742780202(V1_0::ErrorStatus status, const std::vector<bool>& supported) 183*3e777be0SXin Li         {
184*3e777be0SXin Li             error = status;
185*3e777be0SXin Li             sup = supported;
186*3e777be0SXin Li         };
187*3e777be0SXin Li 
188*3e777be0SXin Li     HalPolicy::Model model = {};
189*3e777be0SXin Li 
190*3e777be0SXin Li     // operands
191*3e777be0SXin Li     int32_t actValue      = 0;
192*3e777be0SXin Li     float   weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
193*3e777be0SXin Li                              0, 1, 0, 0, 0, 0, 0, 0,
194*3e777be0SXin Li                              0, 0, 1, 0, 0, 0, 0, 0,
195*3e777be0SXin Li                              0, 0, 0, 1, 0, 0, 0, 0,
196*3e777be0SXin Li                              0, 0, 0, 0, 1, 0, 0, 0,
197*3e777be0SXin Li                              0, 0, 0, 0, 0, 1, 0, 0,
198*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 1, 0,
199*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 0, 1}; //identity
200*3e777be0SXin Li     float   biasValue[]   = {0, 0, 0, 0, 0, 0, 0, 0};
201*3e777be0SXin Li 
202*3e777be0SXin Li     // fully connected operation
203*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 2, 2, 2});
204*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8}, weightValue);
205*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{8}, biasValue);
206*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, actValue);
207*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
208*3e777be0SXin Li 
209*3e777be0SXin Li     model.operations.resize(1);
210*3e777be0SXin Li 
211*3e777be0SXin Li     model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
212*3e777be0SXin Li     model.operations[0].inputs  = hidl_vec<uint32_t>{0,1,2,3};
213*3e777be0SXin Li     model.operations[0].outputs = hidl_vec<uint32_t>{4};
214*3e777be0SXin Li 
215*3e777be0SXin Li     // make the prepared model
216*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
217*3e777be0SXin Li 
218*3e777be0SXin Li     // construct the request
219*3e777be0SXin Li     V1_0::DataLocation inloc = {};
220*3e777be0SXin Li     inloc.poolIndex          = 0;
221*3e777be0SXin Li     inloc.offset             = 0;
222*3e777be0SXin Li     inloc.length             = 8 * sizeof(float);
223*3e777be0SXin Li     RequestArgument input    = {};
224*3e777be0SXin Li     input.location           = inloc;
225*3e777be0SXin Li     input.dimensions         = hidl_vec<uint32_t>{};
226*3e777be0SXin Li 
227*3e777be0SXin Li     V1_0::DataLocation outloc = {};
228*3e777be0SXin Li     outloc.poolIndex          = 1;
229*3e777be0SXin Li     outloc.offset             = 0;
230*3e777be0SXin Li     outloc.length             = 8 * sizeof(float);
231*3e777be0SXin Li     RequestArgument output    = {};
232*3e777be0SXin Li     output.location           = outloc;
233*3e777be0SXin Li     output.dimensions         = hidl_vec<uint32_t>{};
234*3e777be0SXin Li 
235*3e777be0SXin Li     V1_0::Request request = {};
236*3e777be0SXin Li     request.inputs  = hidl_vec<RequestArgument>{input};
237*3e777be0SXin Li     request.outputs = hidl_vec<RequestArgument>{output};
238*3e777be0SXin Li 
239*3e777be0SXin Li     // set the input data
240*3e777be0SXin Li     float indata[] = {1,2,3,4,5,6,7,8};
241*3e777be0SXin Li     AddPoolAndSetData(8, request, indata);
242*3e777be0SXin Li 
243*3e777be0SXin Li     // add memory for the output
244*3e777be0SXin Li     android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
245*3e777be0SXin Li     float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
246*3e777be0SXin Li 
247*3e777be0SXin Li     // run the execution
248*3e777be0SXin Li     if (preparedModel != nullptr)
249*3e777be0SXin Li     {
250*3e777be0SXin Li         Execute(preparedModel, request);
251*3e777be0SXin Li     }
252*3e777be0SXin Li 
253*3e777be0SXin Li     // check the result
254*3e777be0SXin Li     DOCTEST_CHECK(outdata[0] == 1);
255*3e777be0SXin Li     DOCTEST_CHECK(outdata[1] == 2);
256*3e777be0SXin Li     DOCTEST_CHECK(outdata[2] == 3);
257*3e777be0SXin Li     DOCTEST_CHECK(outdata[3] == 4);
258*3e777be0SXin Li     DOCTEST_CHECK(outdata[4] == 5);
259*3e777be0SXin Li     DOCTEST_CHECK(outdata[5] == 6);
260*3e777be0SXin Li     DOCTEST_CHECK(outdata[6] == 7);
261*3e777be0SXin Li     DOCTEST_CHECK(outdata[7] == 8);
262*3e777be0SXin Li }
263*3e777be0SXin Li 
264*3e777be0SXin Li DOCTEST_TEST_CASE("TestFullyConnectedWeightsAsInput")
265*3e777be0SXin Li {
266*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
267*3e777be0SXin Li 
268*3e777be0SXin Li     V1_0::ErrorStatus error;
269*3e777be0SXin Li     std::vector<bool> sup;
270*3e777be0SXin Li 
271*3e777be0SXin Li     ArmnnDriver::getSupportedOperations_cb cb = [&](V1_0::ErrorStatus status, const std::vector<bool>& supported)
__anon493742780302(V1_0::ErrorStatus status, const std::vector<bool>& supported) 272*3e777be0SXin Li     {
273*3e777be0SXin Li         error = status;
274*3e777be0SXin Li         sup = supported;
275*3e777be0SXin Li     };
276*3e777be0SXin Li 
277*3e777be0SXin Li     HalPolicy::Model model = {};
278*3e777be0SXin Li 
279*3e777be0SXin Li     // operands
280*3e777be0SXin Li     int32_t actValue      = 0;
281*3e777be0SXin Li     float   weightValue[] = {1, 0, 0, 0, 0, 0, 0, 0,
282*3e777be0SXin Li                              0, 1, 0, 0, 0, 0, 0, 0,
283*3e777be0SXin Li                              0, 0, 1, 0, 0, 0, 0, 0,
284*3e777be0SXin Li                              0, 0, 0, 1, 0, 0, 0, 0,
285*3e777be0SXin Li                              0, 0, 0, 0, 1, 0, 0, 0,
286*3e777be0SXin Li                              0, 0, 0, 0, 0, 1, 0, 0,
287*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 1, 0,
288*3e777be0SXin Li                              0, 0, 0, 0, 0, 0, 0, 1}; //identity
289*3e777be0SXin Li     float   biasValue[]   = {0, 0, 0, 0, 0, 0, 0, 0};
290*3e777be0SXin Li 
291*3e777be0SXin Li     // fully connected operation
292*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1, 1, 8});
293*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8, 8});
294*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{8});
295*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, actValue);
296*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 8});
297*3e777be0SXin Li 
298*3e777be0SXin Li     model.operations.resize(1);
299*3e777be0SXin Li 
300*3e777be0SXin Li     model.operations[0].type = HalPolicy::OperationType::FULLY_CONNECTED;
301*3e777be0SXin Li     model.operations[0].inputs  = hidl_vec<uint32_t>{0,1,2,3};
302*3e777be0SXin Li     model.operations[0].outputs = hidl_vec<uint32_t>{4};
303*3e777be0SXin Li 
304*3e777be0SXin Li     // make the prepared model
305*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModel = PrepareModel(model, *driver);
306*3e777be0SXin Li 
307*3e777be0SXin Li     // construct the request for input
308*3e777be0SXin Li     V1_0::DataLocation inloc = {};
309*3e777be0SXin Li     inloc.poolIndex          = 0;
310*3e777be0SXin Li     inloc.offset             = 0;
311*3e777be0SXin Li     inloc.length             = 8 * sizeof(float);
312*3e777be0SXin Li     RequestArgument input    = {};
313*3e777be0SXin Li     input.location           = inloc;
314*3e777be0SXin Li     input.dimensions         = hidl_vec<uint32_t>{1, 1, 1, 8};
315*3e777be0SXin Li 
316*3e777be0SXin Li     // construct the request for weights as input
317*3e777be0SXin Li     V1_0::DataLocation wloc = {};
318*3e777be0SXin Li     wloc.poolIndex          = 1;
319*3e777be0SXin Li     wloc.offset             = 0;
320*3e777be0SXin Li     wloc.length             = 64 * sizeof(float);
321*3e777be0SXin Li     RequestArgument weights = {};
322*3e777be0SXin Li     weights.location        = wloc;
323*3e777be0SXin Li     weights.dimensions      = hidl_vec<uint32_t>{8, 8};
324*3e777be0SXin Li 
325*3e777be0SXin Li     // construct the request for bias as input
326*3e777be0SXin Li     V1_0::DataLocation bloc = {};
327*3e777be0SXin Li     bloc.poolIndex          = 2;
328*3e777be0SXin Li     bloc.offset             = 0;
329*3e777be0SXin Li     bloc.length             = 8 * sizeof(float);
330*3e777be0SXin Li     RequestArgument bias    = {};
331*3e777be0SXin Li     bias.location           = bloc;
332*3e777be0SXin Li     bias.dimensions         = hidl_vec<uint32_t>{8};
333*3e777be0SXin Li 
334*3e777be0SXin Li     V1_0::DataLocation outloc = {};
335*3e777be0SXin Li     outloc.poolIndex          = 3;
336*3e777be0SXin Li     outloc.offset             = 0;
337*3e777be0SXin Li     outloc.length             = 8 * sizeof(float);
338*3e777be0SXin Li     RequestArgument output    = {};
339*3e777be0SXin Li     output.location           = outloc;
340*3e777be0SXin Li     output.dimensions         = hidl_vec<uint32_t>{1, 8};
341*3e777be0SXin Li 
342*3e777be0SXin Li     V1_0::Request request = {};
343*3e777be0SXin Li     request.inputs  = hidl_vec<RequestArgument>{input, weights, bias};
344*3e777be0SXin Li     request.outputs = hidl_vec<RequestArgument>{output};
345*3e777be0SXin Li 
346*3e777be0SXin Li     // set the input data
347*3e777be0SXin Li     float indata[] = {1,2,3,4,5,6,7,8};
348*3e777be0SXin Li     AddPoolAndSetData(8, request, indata);
349*3e777be0SXin Li 
350*3e777be0SXin Li     // set the weights data
351*3e777be0SXin Li     AddPoolAndSetData(64, request, weightValue);
352*3e777be0SXin Li     // set the bias data
353*3e777be0SXin Li     AddPoolAndSetData(8, request, biasValue);
354*3e777be0SXin Li 
355*3e777be0SXin Li     // add memory for the output
356*3e777be0SXin Li     android::sp<IMemory> outMemory = AddPoolAndGetData<float>(8, request);
357*3e777be0SXin Li     float* outdata = static_cast<float*>(static_cast<void*>(outMemory->getPointer()));
358*3e777be0SXin Li 
359*3e777be0SXin Li     // run the execution
360*3e777be0SXin Li     if (preparedModel != nullptr)
361*3e777be0SXin Li     {
362*3e777be0SXin Li         Execute(preparedModel, request);
363*3e777be0SXin Li     }
364*3e777be0SXin Li 
365*3e777be0SXin Li     // check the result
366*3e777be0SXin Li     DOCTEST_CHECK(outdata[0] == 1);
367*3e777be0SXin Li     DOCTEST_CHECK(outdata[1] == 2);
368*3e777be0SXin Li     DOCTEST_CHECK(outdata[2] == 3);
369*3e777be0SXin Li     DOCTEST_CHECK(outdata[3] == 4);
370*3e777be0SXin Li     DOCTEST_CHECK(outdata[4] == 5);
371*3e777be0SXin Li     DOCTEST_CHECK(outdata[5] == 6);
372*3e777be0SXin Li     DOCTEST_CHECK(outdata[6] == 7);
373*3e777be0SXin Li     DOCTEST_CHECK(outdata[7] == 8);
374*3e777be0SXin Li }
375*3e777be0SXin Li 
376*3e777be0SXin Li }
377