xref: /aosp_15_r20/external/android-nn-driver/test/Concurrent.cpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #include "DriverTestHelpers.hpp"
7*3e777be0SXin Li 
8*3e777be0SXin Li #include <log/log.h>
9*3e777be0SXin Li 
10*3e777be0SXin Li DOCTEST_TEST_SUITE("ConcurrentDriverTests")
11*3e777be0SXin Li {
12*3e777be0SXin Li using ArmnnDriver   = armnn_driver::ArmnnDriver;
13*3e777be0SXin Li using DriverOptions = armnn_driver::DriverOptions;
14*3e777be0SXin Li using HalPolicy     = armnn_driver::hal_1_0::HalPolicy;
15*3e777be0SXin Li using RequestArgument = V1_0::RequestArgument;
16*3e777be0SXin Li 
17*3e777be0SXin Li using namespace android::nn;
18*3e777be0SXin Li using namespace android::hardware;
19*3e777be0SXin Li using namespace driverTestHelpers;
20*3e777be0SXin Li using namespace armnn_driver;
21*3e777be0SXin Li 
22*3e777be0SXin Li // Add our own test for concurrent execution
23*3e777be0SXin Li // The main point of this test is to check that multiple requests can be
24*3e777be0SXin Li // executed without waiting for the callback from previous execution.
25*3e777be0SXin Li // The operations performed are not significant.
26*3e777be0SXin Li DOCTEST_TEST_CASE("ConcurrentExecute")
27*3e777be0SXin Li {
28*3e777be0SXin Li     ALOGI("ConcurrentExecute: entry");
29*3e777be0SXin Li 
30*3e777be0SXin Li     auto driver = std::make_unique<ArmnnDriver>(DriverOptions(armnn::Compute::CpuRef));
31*3e777be0SXin Li     HalPolicy::Model model = {};
32*3e777be0SXin Li 
33*3e777be0SXin Li     // add operands
34*3e777be0SXin Li     int32_t actValue      = 0;
35*3e777be0SXin Li     float   weightValue[] = {2, 4, 1};
36*3e777be0SXin Li     float   biasValue[]   = {4};
37*3e777be0SXin Li 
38*3e777be0SXin Li     AddInputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3});
39*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 3}, weightValue);
40*3e777be0SXin Li     AddTensorOperand<HalPolicy>(model, hidl_vec<uint32_t>{1}, biasValue);
41*3e777be0SXin Li     AddIntOperand<HalPolicy>(model, actValue);
42*3e777be0SXin Li     AddOutputOperand<HalPolicy>(model, hidl_vec<uint32_t>{1, 1});
43*3e777be0SXin Li 
44*3e777be0SXin Li     // make the fully connected operation
45*3e777be0SXin Li     model.operations.resize(1);
46*3e777be0SXin Li     model.operations[0].type    = HalPolicy::OperationType::FULLY_CONNECTED;
47*3e777be0SXin Li     model.operations[0].inputs  = hidl_vec<uint32_t>{0, 1, 2, 3};
48*3e777be0SXin Li     model.operations[0].outputs = hidl_vec<uint32_t>{4};
49*3e777be0SXin Li 
50*3e777be0SXin Li     // make the prepared models
51*3e777be0SXin Li     const size_t maxRequests = 5;
52*3e777be0SXin Li     size_t preparedModelsSize = 0;
53*3e777be0SXin Li     android::sp<V1_0::IPreparedModel> preparedModels[maxRequests];
54*3e777be0SXin Li     for (size_t i = 0; i < maxRequests; ++i)
55*3e777be0SXin Li     {
56*3e777be0SXin Li         auto preparedModel = PrepareModel(model, *driver);
57*3e777be0SXin Li         if (preparedModel.get() != nullptr)
58*3e777be0SXin Li         {
59*3e777be0SXin Li             preparedModels[i] = PrepareModel(model, *driver);
60*3e777be0SXin Li             preparedModelsSize++;
61*3e777be0SXin Li         }
62*3e777be0SXin Li     }
63*3e777be0SXin Li 
64*3e777be0SXin Li     DOCTEST_CHECK(maxRequests == preparedModelsSize);
65*3e777be0SXin Li 
66*3e777be0SXin Li     // construct the request data
67*3e777be0SXin Li     V1_0::DataLocation inloc = {};
68*3e777be0SXin Li     inloc.poolIndex          = 0;
69*3e777be0SXin Li     inloc.offset             = 0;
70*3e777be0SXin Li     inloc.length             = 3 * sizeof(float);
71*3e777be0SXin Li     RequestArgument input    = {};
72*3e777be0SXin Li     input.location           = inloc;
73*3e777be0SXin Li     input.dimensions         = hidl_vec<uint32_t>{};
74*3e777be0SXin Li 
75*3e777be0SXin Li     V1_0::DataLocation outloc = {};
76*3e777be0SXin Li     outloc.poolIndex          = 1;
77*3e777be0SXin Li     outloc.offset             = 0;
78*3e777be0SXin Li     outloc.length             = 1 * sizeof(float);
79*3e777be0SXin Li     RequestArgument output    = {};
80*3e777be0SXin Li     output.location           = outloc;
81*3e777be0SXin Li     output.dimensions         = hidl_vec<uint32_t>{};
82*3e777be0SXin Li 
83*3e777be0SXin Li     // build the requests
84*3e777be0SXin Li     V1_0::Request requests[maxRequests];
85*3e777be0SXin Li     android::sp<IMemory> inMemory[maxRequests];
86*3e777be0SXin Li     android::sp<IMemory> outMemory[maxRequests];
87*3e777be0SXin Li     float indata[] = {2, 32, 16};
88*3e777be0SXin Li     float* outdata[maxRequests];
89*3e777be0SXin Li     for (size_t i = 0; i < maxRequests; ++i)
90*3e777be0SXin Li     {
91*3e777be0SXin Li         requests[i].inputs  = hidl_vec<RequestArgument>{input};
92*3e777be0SXin Li         requests[i].outputs = hidl_vec<RequestArgument>{output};
93*3e777be0SXin Li         // set the input data (matching source test)
94*3e777be0SXin Li         inMemory[i] = AddPoolAndSetData<float>(3, requests[i], indata);
95*3e777be0SXin Li         // add memory for the output
96*3e777be0SXin Li         outMemory[i] = AddPoolAndGetData<float>(1, requests[i]);
97*3e777be0SXin Li         outdata[i] = static_cast<float*>(static_cast<void*>(outMemory[i]->getPointer()));
98*3e777be0SXin Li     }
99*3e777be0SXin Li 
100*3e777be0SXin Li     // invoke the execution of the requests
101*3e777be0SXin Li     ALOGI("ConcurrentExecute: executing requests");
102*3e777be0SXin Li     android::sp<ExecutionCallback> cb[maxRequests];
103*3e777be0SXin Li     for (size_t i = 0; i < maxRequests; ++i)
104*3e777be0SXin Li     {
105*3e777be0SXin Li         cb[i] = ExecuteNoWait(preparedModels[i], requests[i]);
106*3e777be0SXin Li     }
107*3e777be0SXin Li 
108*3e777be0SXin Li     // wait for the requests to complete
109*3e777be0SXin Li     ALOGI("ConcurrentExecute: waiting for callbacks");
110*3e777be0SXin Li     for (size_t i = 0; i < maxRequests; ++i)
111*3e777be0SXin Li     {
112*3e777be0SXin Li         DOCTEST_CHECK(cb[i]);
113*3e777be0SXin Li         cb[i]->wait();
114*3e777be0SXin Li     }
115*3e777be0SXin Li 
116*3e777be0SXin Li     // check the results
117*3e777be0SXin Li     ALOGI("ConcurrentExecute: validating results");
118*3e777be0SXin Li     for (size_t i = 0; i < maxRequests; ++i)
119*3e777be0SXin Li     {
120*3e777be0SXin Li         DOCTEST_CHECK(outdata[i][0] == 152);
121*3e777be0SXin Li     }
122*3e777be0SXin Li     ALOGI("ConcurrentExecute: exit");
123*3e777be0SXin Li }
124*3e777be0SXin Li 
125*3e777be0SXin Li }
126