1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li // Note: the ArmnnFencedExecutionCallback and code snippet in the executeFenced() function
6*3e777be0SXin Li // in this file is based on Android code
7*3e777be0SXin Li // under the Apache 2.0 license. See comments below for details.
8*3e777be0SXin Li //
9*3e777be0SXin Li
10*3e777be0SXin Li #define LOG_TAG "ArmnnDriver"
11*3e777be0SXin Li
12*3e777be0SXin Li #include "ArmnnPreparedModel_1_3.hpp"
13*3e777be0SXin Li #include "Utils.hpp"
14*3e777be0SXin Li
15*3e777be0SXin Li #include <armnn/Types.hpp>
16*3e777be0SXin Li
17*3e777be0SXin Li #include <Utils.h>
18*3e777be0SXin Li #include <android/sync.h>
19*3e777be0SXin Li #include <log/log.h>
20*3e777be0SXin Li #include <OperationsUtils.h>
21*3e777be0SXin Li #include <ExecutionBurstServer.h>
22*3e777be0SXin Li #include <ValidateHal.h>
23*3e777be0SXin Li
24*3e777be0SXin Li #include <chrono>
25*3e777be0SXin Li #include <cinttypes>
26*3e777be0SXin Li
27*3e777be0SXin Li #ifdef ARMNN_ANDROID_S
28*3e777be0SXin Li #include <LegacyUtils.h>
29*3e777be0SXin Li #endif
30*3e777be0SXin Li
31*3e777be0SXin Li using namespace android;
32*3e777be0SXin Li using namespace android::hardware;
33*3e777be0SXin Li
34*3e777be0SXin Li namespace {
35*3e777be0SXin Li
// Sentinel timing value (all fields UINT64_MAX) reported when timing was not measured.
static const V1_2::Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
using namespace armnn_driver;
38*3e777be0SXin Li using TimePoint = std::chrono::steady_clock::time_point;
39*3e777be0SXin Li
Now()40*3e777be0SXin Li TimePoint Now()
41*3e777be0SXin Li {
42*3e777be0SXin Li return std::chrono::steady_clock::now();
43*3e777be0SXin Li }
44*3e777be0SXin Li
MicrosecondsDuration(TimePoint endPoint,TimePoint startPoint)45*3e777be0SXin Li unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
46*3e777be0SXin Li {
47*3e777be0SXin Li return static_cast<unsigned long>(std::chrono::duration_cast<std::chrono::microseconds>(
48*3e777be0SXin Li endPoint - startPoint).count());
49*3e777be0SXin Li }
50*3e777be0SXin Li
NotifyCallbackAndCheck(const::android::sp<V1_0::IExecutionCallback> & callback,V1_3::ErrorStatus errorStatus,std::vector<V1_2::OutputShape>,const V1_2::Timing,std::string callingFunction)51*3e777be0SXin Li void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback,
52*3e777be0SXin Li V1_3::ErrorStatus errorStatus,
53*3e777be0SXin Li std::vector<V1_2::OutputShape>,
54*3e777be0SXin Li const V1_2::Timing,
55*3e777be0SXin Li std::string callingFunction)
56*3e777be0SXin Li {
57*3e777be0SXin Li Return<void> returned = callback->notify(convertToV1_0(errorStatus));
58*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
59*3e777be0SXin Li if (!returned.isOk())
60*3e777be0SXin Li {
61*3e777be0SXin Li ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
62*3e777be0SXin Li callingFunction.c_str(), returned.description().c_str());
63*3e777be0SXin Li }
64*3e777be0SXin Li }
65*3e777be0SXin Li
NotifyCallbackAndCheck(const::android::sp<V1_2::IExecutionCallback> & callback,V1_3::ErrorStatus errorStatus,std::vector<V1_2::OutputShape> outputShapes,const V1_2::Timing timing,std::string callingFunction)66*3e777be0SXin Li void NotifyCallbackAndCheck(const ::android::sp<V1_2::IExecutionCallback>& callback,
67*3e777be0SXin Li V1_3::ErrorStatus errorStatus,
68*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
69*3e777be0SXin Li const V1_2::Timing timing,
70*3e777be0SXin Li std::string callingFunction)
71*3e777be0SXin Li {
72*3e777be0SXin Li Return<void> returned = callback->notify_1_2(convertToV1_0(errorStatus), outputShapes, timing);
73*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
74*3e777be0SXin Li if (!returned.isOk())
75*3e777be0SXin Li {
76*3e777be0SXin Li ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
77*3e777be0SXin Li callingFunction.c_str(), returned.description().c_str());
78*3e777be0SXin Li }
79*3e777be0SXin Li }
80*3e777be0SXin Li
NotifyCallbackAndCheck(const::android::sp<V1_3::IExecutionCallback> & callback,V1_3::ErrorStatus errorStatus,std::vector<V1_2::OutputShape> outputShapes,const V1_2::Timing timing,std::string callingFunction)81*3e777be0SXin Li void NotifyCallbackAndCheck(const ::android::sp<V1_3::IExecutionCallback>& callback,
82*3e777be0SXin Li V1_3::ErrorStatus errorStatus,
83*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
84*3e777be0SXin Li const V1_2::Timing timing,
85*3e777be0SXin Li std::string callingFunction)
86*3e777be0SXin Li {
87*3e777be0SXin Li Return<void> returned = callback->notify_1_3(errorStatus, outputShapes, timing);
88*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
89*3e777be0SXin Li if (!returned.isOk())
90*3e777be0SXin Li {
91*3e777be0SXin Li ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
92*3e777be0SXin Li callingFunction.c_str(), returned.description().c_str());
93*3e777be0SXin Li }
94*3e777be0SXin Li }
95*3e777be0SXin Li
ValidateRequestArgument(const V1_0::RequestArgument & requestArg,const armnn::TensorInfo & tensorInfo)96*3e777be0SXin Li bool ValidateRequestArgument(const V1_0::RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
97*3e777be0SXin Li {
98*3e777be0SXin Li if (requestArg.dimensions.size() != 0)
99*3e777be0SXin Li {
100*3e777be0SXin Li if (requestArg.dimensions.size() != tensorInfo.GetNumDimensions())
101*3e777be0SXin Li {
102*3e777be0SXin Li ALOGE("Mismatched dimensions (request argument: %zu, expected: %u)",
103*3e777be0SXin Li requestArg.dimensions.size(), tensorInfo.GetNumDimensions());
104*3e777be0SXin Li return false;
105*3e777be0SXin Li }
106*3e777be0SXin Li
107*3e777be0SXin Li for (unsigned int d = 0; d < tensorInfo.GetNumDimensions(); ++d)
108*3e777be0SXin Li {
109*3e777be0SXin Li if (requestArg.dimensions[d] != 0 && requestArg.dimensions[d] != tensorInfo.GetShape()[d])
110*3e777be0SXin Li {
111*3e777be0SXin Li ALOGE("Mismatched size for dimension %d (request argument: %u, expected %u)",
112*3e777be0SXin Li d, requestArg.dimensions[d], tensorInfo.GetShape()[d]);
113*3e777be0SXin Li return false;
114*3e777be0SXin Li }
115*3e777be0SXin Li }
116*3e777be0SXin Li }
117*3e777be0SXin Li
118*3e777be0SXin Li return true;
119*3e777be0SXin Li }
120*3e777be0SXin Li
GetTensorForRequestArgument(const V1_0::RequestArgument & requestArg,const armnn::TensorInfo & tensorInfo,const std::vector<::android::nn::RunTimePoolInfo> & requestPools)121*3e777be0SXin Li armnn::Tensor GetTensorForRequestArgument(const V1_0::RequestArgument& requestArg,
122*3e777be0SXin Li const armnn::TensorInfo& tensorInfo,
123*3e777be0SXin Li const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
124*3e777be0SXin Li {
125*3e777be0SXin Li if (!ValidateRequestArgument(requestArg, tensorInfo))
126*3e777be0SXin Li {
127*3e777be0SXin Li return armnn::Tensor();
128*3e777be0SXin Li }
129*3e777be0SXin Li
130*3e777be0SXin Li return armnn::Tensor(tensorInfo, GetMemoryFromPool(requestArg.location, requestPools));
131*3e777be0SXin Li }
132*3e777be0SXin Li
// Concatenate a prefix with a numeric index, e.g. ("input", 2) -> "input2".
inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t index)
{
    std::string name(tensorNamePrefix);
    name += std::to_string(index);
    return name;
}
137*3e777be0SXin Li
138*3e777be0SXin Li } // anonymous namespace
139*3e777be0SXin Li
140*3e777be0SXin Li using namespace android::hardware;
141*3e777be0SXin Li
142*3e777be0SXin Li namespace armnn_driver
143*3e777be0SXin Li {
144*3e777be0SXin Li
// Single static request thread shared by all instances of this prepared-model type;
// callback-style executions are queued onto it.
template<typename HalVersion>
RequestThread_1_3<ArmnnPreparedModel_1_3, HalVersion, CallbackContext_1_3>
        ArmnnPreparedModel_1_3<HalVersion>::m_RequestThread;

// Shared threadpool for asynchronous model execution; created lazily by the first
// prepared model constructed with asyncModelExecutionEnabled (see constructors).
template<typename HalVersion>
std::unique_ptr<armnn::Threadpool> ArmnnPreparedModel_1_3<HalVersion>::m_Threadpool(nullptr);
151*3e777be0SXin Li
152*3e777be0SXin Li template<typename HalVersion>
153*3e777be0SXin Li template<typename TensorBindingCollection>
DumpTensorsIfRequired(char const * tensorNamePrefix,const TensorBindingCollection & tensorBindings)154*3e777be0SXin Li void ArmnnPreparedModel_1_3<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix,
155*3e777be0SXin Li const TensorBindingCollection& tensorBindings)
156*3e777be0SXin Li {
157*3e777be0SXin Li if (!m_RequestInputsAndOutputsDumpDir.empty())
158*3e777be0SXin Li {
159*3e777be0SXin Li const std::string requestName = std::to_string(m_NetworkId) + "_" + std::to_string(m_RequestCount) + ".dump";
160*3e777be0SXin Li for (std::size_t i = 0u; i < tensorBindings.size(); ++i)
161*3e777be0SXin Li {
162*3e777be0SXin Li DumpTensor(m_RequestInputsAndOutputsDumpDir,
163*3e777be0SXin Li requestName,
164*3e777be0SXin Li BuildTensorName(tensorNamePrefix, i),
165*3e777be0SXin Li tensorBindings[i].second);
166*3e777be0SXin Li }
167*3e777be0SXin Li }
168*3e777be0SXin Li }
169*3e777be0SXin Li
// Constructor used when the model is prepared from a V1_3::Model.
// Enables GPU profiling when requested and, for asynchronous execution, creates
// (or joins) the shared static threadpool with one working-memory handle per thread.
template<typename HalVersion>
ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
                                                           armnn::IRuntime* runtime,
                                                           const V1_3::Model& model,
                                                           const std::string& requestInputsAndOutputsDumpDir,
                                                           const bool gpuProfilingEnabled,
                                                           V1_3::Priority priority,
                                                           const bool asyncModelExecutionEnabled,
                                                           const unsigned int numberOfThreads,
                                                           const bool importEnabled,
                                                           const bool exportEnabled)
    : m_NetworkId(networkId)
    , m_Runtime(runtime)
    , m_Model(model)
    , m_RequestCount(0)
    , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
    , m_GpuProfilingEnabled(gpuProfilingEnabled)
    , m_ModelPriority(priority)
    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
    , m_EnableImport(importEnabled)
    , m_EnableExport(exportEnabled)
    , m_PreparedFromCache(false)
{
    // Enable profiling if required.
    m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);

    if (m_AsyncModelExecutionEnabled)
    {
        // One working-memory handle per thread so executions can run concurrently.
        std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
        for (unsigned int i=0; i < numberOfThreads; ++i)
        {
            memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId));
        }

        // m_Threadpool is static and shared across prepared models: create it on
        // first use, otherwise just register this network's handles with it.
        if (!m_Threadpool)
        {
            m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles);
        }
        else
        {
            m_Threadpool->LoadMemHandles(memHandles);
        }

        // Keep one handle alive on this object for scheduling executions.
        m_WorkingMemHandle = memHandles.back();
    }
}
216*3e777be0SXin Li
// Constructor used when the model is prepared from a cached network blob; no
// V1_3::Model is available, so m_Model is left default-constructed and
// m_PreparedFromCache records whether validation against the model is possible.
// Async-execution setup mirrors the model-based constructor above.
template<typename HalVersion>
ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
                                                           armnn::IRuntime* runtime,
                                                           const std::string& requestInputsAndOutputsDumpDir,
                                                           const bool gpuProfilingEnabled,
                                                           V1_3::Priority priority,
                                                           const bool asyncModelExecutionEnabled,
                                                           const unsigned int numberOfThreads,
                                                           const bool importEnabled,
                                                           const bool exportEnabled,
                                                           const bool preparedFromCache)
    : m_NetworkId(networkId)
    , m_Runtime(runtime)
    , m_RequestCount(0)
    , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
    , m_GpuProfilingEnabled(gpuProfilingEnabled)
    , m_ModelPriority(priority)
    , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
    , m_EnableImport(importEnabled)
    , m_EnableExport(exportEnabled)
    , m_PreparedFromCache(preparedFromCache)
{
    // Enable profiling if required.
    m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);

    if (m_AsyncModelExecutionEnabled)
    {
        // One working-memory handle per thread so executions can run concurrently.
        std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
        for (unsigned int i=0; i < numberOfThreads; ++i)
        {
            memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId));
        }

        // m_Threadpool is static and shared across prepared models: create it on
        // first use, otherwise just register this network's handles with it.
        if (!m_Threadpool)
        {
            m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles);
        }
        else
        {
            m_Threadpool->LoadMemHandles(memHandles);
        }

        // Keep one handle alive on this object for scheduling executions.
        m_WorkingMemHandle = memHandles.back();
    }
}
262*3e777be0SXin Li
// Destructor: flush profiling output, unload the network from the runtime and
// release this network's working-memory handles from the shared threadpool.
template<typename HalVersion>
ArmnnPreparedModel_1_3<HalVersion>::~ArmnnPreparedModel_1_3()
{
    // Get a hold of the profiler used by this model.
    std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
    if (profiler && m_GpuProfilingEnabled)
    {
        // Dump the profiling info to a file if required.
        DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId,
                                    profiler.get());
    }

    // Unload the network associated with this model.
    m_Runtime->UnloadNetwork(m_NetworkId);

    // Unload the network memhandles from the threadpool.
    // m_Threadpool is guaranteed non-null here: the constructors create it whenever
    // m_AsyncModelExecutionEnabled is set.
    if (m_AsyncModelExecutionEnabled)
    {
        m_Threadpool->UnloadMemHandles(m_NetworkId);
    }
}
284*3e777be0SXin Li
285*3e777be0SXin Li template<typename HalVersion>
execute(const V1_0::Request & request,const::android::sp<V1_0::IExecutionCallback> & callback)286*3e777be0SXin Li Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::execute(const V1_0::Request& request,
287*3e777be0SXin Li const ::android::sp<V1_0::IExecutionCallback>& callback)
288*3e777be0SXin Li {
289*3e777be0SXin Li if (callback.get() == nullptr)
290*3e777be0SXin Li {
291*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::execute invalid callback passed");
292*3e777be0SXin Li return V1_0::ErrorStatus::INVALID_ARGUMENT;
293*3e777be0SXin Li }
294*3e777be0SXin Li
295*3e777be0SXin Li auto cb = [callback](V1_3::ErrorStatus errorStatus,
296*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
297*3e777be0SXin Li const V1_2::Timing& timing,
298*3e777be0SXin Li std::string callingFunction)
299*3e777be0SXin Li {
300*3e777be0SXin Li NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
301*3e777be0SXin Li };
302*3e777be0SXin Li
303*3e777be0SXin Li
304*3e777be0SXin Li return convertToV1_0(Execute(convertToV1_3(request), V1_2::MeasureTiming::NO, cb));
305*3e777be0SXin Li }
306*3e777be0SXin Li
307*3e777be0SXin Li template<typename HalVersion>
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measureTiming,const sp<V1_2::IExecutionCallback> & callback)308*3e777be0SXin Li Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::execute_1_2(
309*3e777be0SXin Li const V1_0::Request& request,
310*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
311*3e777be0SXin Li const sp<V1_2::IExecutionCallback>& callback)
312*3e777be0SXin Li {
313*3e777be0SXin Li if (callback.get() == nullptr)
314*3e777be0SXin Li {
315*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::execute_1_2 invalid callback passed");
316*3e777be0SXin Li return V1_0::ErrorStatus::INVALID_ARGUMENT;
317*3e777be0SXin Li }
318*3e777be0SXin Li
319*3e777be0SXin Li auto cb = [callback](V1_3::ErrorStatus errorStatus,
320*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
321*3e777be0SXin Li const V1_2::Timing& timing,
322*3e777be0SXin Li std::string callingFunction)
323*3e777be0SXin Li {
324*3e777be0SXin Li NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
325*3e777be0SXin Li };
326*3e777be0SXin Li
327*3e777be0SXin Li return convertToV1_0(Execute(convertToV1_3(request), measureTiming, cb));
328*3e777be0SXin Li }
329*3e777be0SXin Li
330*3e777be0SXin Li template<typename HalVersion>
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,const sp<V1_3::IExecutionCallback> & callback)331*3e777be0SXin Li Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::execute_1_3(
332*3e777be0SXin Li const V1_3::Request& request,
333*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
334*3e777be0SXin Li const V1_3::OptionalTimePoint&,
335*3e777be0SXin Li const V1_3::OptionalTimeoutDuration&,
336*3e777be0SXin Li const sp<V1_3::IExecutionCallback>& callback)
337*3e777be0SXin Li {
338*3e777be0SXin Li if (callback.get() == nullptr)
339*3e777be0SXin Li {
340*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::execute_1_3 invalid callback passed");
341*3e777be0SXin Li return V1_3::ErrorStatus::INVALID_ARGUMENT;
342*3e777be0SXin Li }
343*3e777be0SXin Li
344*3e777be0SXin Li auto cb = [callback](V1_3::ErrorStatus errorStatus,
345*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
346*3e777be0SXin Li const V1_2::Timing& timing,
347*3e777be0SXin Li std::string callingFunction)
348*3e777be0SXin Li {
349*3e777be0SXin Li NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
350*3e777be0SXin Li };
351*3e777be0SXin Li
352*3e777be0SXin Li return Execute(request, measureTiming, cb);
353*3e777be0SXin Li }
354*3e777be0SXin Li
355*3e777be0SXin Li /// This class is inspired by the sample implementation in Android named SampleFencedExecutionCallback.
356*3e777be0SXin Li /// The original code is licensed under Apache-2.0 and can be found at the following link:
357*3e777be0SXin Li /// https://android.googlesource.com/platform/frameworks/ml/+/master/nn/driver/sample/SampleDriver.h
358*3e777be0SXin Li class ArmnnFencedExecutionCallback : public V1_3::IFencedExecutionCallback
359*3e777be0SXin Li {
360*3e777be0SXin Li public:
ArmnnFencedExecutionCallback(V1_3::ErrorStatus errorStatus,V1_2::Timing timing,V1_2::Timing fenceTiming)361*3e777be0SXin Li ArmnnFencedExecutionCallback(V1_3::ErrorStatus errorStatus, V1_2::Timing timing, V1_2::Timing fenceTiming)
362*3e777be0SXin Li : m_ErrorStatus(errorStatus), m_Timing(timing), m_FenceTiming(fenceTiming) {}
~ArmnnFencedExecutionCallback()363*3e777be0SXin Li ~ArmnnFencedExecutionCallback() {}
364*3e777be0SXin Li
getExecutionInfo(getExecutionInfo_cb callback)365*3e777be0SXin Li Return<void> getExecutionInfo(getExecutionInfo_cb callback) override
366*3e777be0SXin Li {
367*3e777be0SXin Li callback(m_ErrorStatus, m_Timing, m_FenceTiming);
368*3e777be0SXin Li return Void();
369*3e777be0SXin Li }
370*3e777be0SXin Li private:
371*3e777be0SXin Li V1_3::ErrorStatus m_ErrorStatus;
372*3e777be0SXin Li V1_2::Timing m_Timing;
373*3e777be0SXin Li V1_2::Timing m_FenceTiming;
374*3e777be0SXin Li };
375*3e777be0SXin Li
376*3e777be0SXin Li template<typename HalVersion>
executeFenced(const V1_3::Request & request,const hidl_vec<hidl_handle> & fenceWaitFor,V1_2::MeasureTiming measureTiming,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration &,executeFenced_cb cb)377*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_3<HalVersion>::executeFenced(const V1_3::Request& request,
378*3e777be0SXin Li const hidl_vec<hidl_handle>& fenceWaitFor,
379*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
380*3e777be0SXin Li const V1_3::OptionalTimePoint& deadline,
381*3e777be0SXin Li const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
382*3e777be0SXin Li const V1_3::OptionalTimeoutDuration&,
383*3e777be0SXin Li executeFenced_cb cb)
384*3e777be0SXin Li {
385*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeFenced(...)");
386*3e777be0SXin Li if (cb == nullptr)
387*3e777be0SXin Li {
388*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::executeFenced invalid callback passed");
389*3e777be0SXin Li cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
390*3e777be0SXin Li return Void();
391*3e777be0SXin Li }
392*3e777be0SXin Li
393*3e777be0SXin Li if (deadline.getDiscriminator() != V1_3::OptionalTimePoint::hidl_discriminator::none)
394*3e777be0SXin Li {
395*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::executeFenced parameter deadline is set but not supported.");
396*3e777be0SXin Li }
397*3e777be0SXin Li
398*3e777be0SXin Li if (loopTimeoutDuration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none)
399*3e777be0SXin Li {
400*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::executeFenced parameter loopTimeoutDuration is set but not supported.");
401*3e777be0SXin Li }
402*3e777be0SXin Li
403*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model, /*allowUnspecifiedOutput=*/false))
404*3e777be0SXin Li {
405*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeFenced outputs must be specified for fenced execution ");
406*3e777be0SXin Li cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
407*3e777be0SXin Li return Void();
408*3e777be0SXin Li }
409*3e777be0SXin Li
410*3e777be0SXin Li ExecutionContext_1_3 ctx;
411*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
412*3e777be0SXin Li {
413*3e777be0SXin Li ctx.measureTimings = measureTiming;
414*3e777be0SXin Li ctx.driverStart = Now();
415*3e777be0SXin Li }
416*3e777be0SXin Li
417*3e777be0SXin Li if (!m_PreparedFromCache)
418*3e777be0SXin Li {
419*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeFenced(): %s", GetModelSummary(m_Model).c_str());
420*3e777be0SXin Li }
421*3e777be0SXin Li m_RequestCount++;
422*3e777be0SXin Li
423*3e777be0SXin Li if (!m_RequestInputsAndOutputsDumpDir.empty())
424*3e777be0SXin Li {
425*3e777be0SXin Li ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&cb));
426*3e777be0SXin Li }
427*3e777be0SXin Li
428*3e777be0SXin Li // This code snippet is inspired by the sample implementation in Android named SampleDriver::executeFenced()
429*3e777be0SXin Li // function. The original code is licensed under Apache-2.0 and can be found at the following link:
430*3e777be0SXin Li // https://android.googlesource.com/platform/frameworks/ml/+/master/nn/driver/sample/SampleDriver.cpp
431*3e777be0SXin Li const auto fenceSize = fenceWaitFor.size();
432*3e777be0SXin Li for (unsigned int index = 0; index < fenceSize; ++index)
433*3e777be0SXin Li {
434*3e777be0SXin Li auto fenceNativeHandle = fenceWaitFor[index].getNativeHandle();
435*3e777be0SXin Li if (!fenceNativeHandle)
436*3e777be0SXin Li {
437*3e777be0SXin Li cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
438*3e777be0SXin Li return Void();
439*3e777be0SXin Li }
440*3e777be0SXin Li
441*3e777be0SXin Li if (fenceNativeHandle->numFds != 1)
442*3e777be0SXin Li {
443*3e777be0SXin Li cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
444*3e777be0SXin Li return Void();
445*3e777be0SXin Li }
446*3e777be0SXin Li
447*3e777be0SXin Li if (sync_wait(fenceNativeHandle->data[0], -1) < 0)
448*3e777be0SXin Li {
449*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::executeFenced sync fence failed.");
450*3e777be0SXin Li cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
451*3e777be0SXin Li return Void();
452*3e777be0SXin Li }
453*3e777be0SXin Li }
454*3e777be0SXin Li
455*3e777be0SXin Li TimePoint fenceExecutionStart;
456*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
457*3e777be0SXin Li {
458*3e777be0SXin Li fenceExecutionStart = Now();
459*3e777be0SXin Li }
460*3e777be0SXin Li
461*3e777be0SXin Li // map the memory pool into shared pointers
462*3e777be0SXin Li // use a shared memory pools vector on the heap, as it is passed to the request thread
463*3e777be0SXin Li auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
464*3e777be0SXin Li
465*3e777be0SXin Li // allocate the tensors on the heap, as they are passed to the request thread
466*3e777be0SXin Li auto inputs = std::make_shared<armnn::InputTensors>();
467*3e777be0SXin Li auto outputs = std::make_shared<armnn::OutputTensors>();
468*3e777be0SXin Li
469*3e777be0SXin Li auto [status, outShapes, timings, message] = PrepareMemoryForIO(*inputs, *outputs, *memPools, request);
470*3e777be0SXin Li if (status != V1_3::ErrorStatus::NONE)
471*3e777be0SXin Li {
472*3e777be0SXin Li cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
473*3e777be0SXin Li return Void();
474*3e777be0SXin Li }
475*3e777be0SXin Li
476*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeFenced(...) before ExecuteGraph");
477*3e777be0SXin Li
478*3e777be0SXin Li // call it with nullCallback for now as we will report the error status from here..
479*3e777be0SXin Li auto nullCallback = [](V1_3::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
480*3e777be0SXin Li CallbackContext_1_3 cbCtx;
481*3e777be0SXin Li cbCtx.callback = nullCallback;
482*3e777be0SXin Li cbCtx.ctx = ctx;
483*3e777be0SXin Li
484*3e777be0SXin Li auto errorStatus = ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
485*3e777be0SXin Li if (errorStatus != V1_3::ErrorStatus::NONE)
486*3e777be0SXin Li {
487*3e777be0SXin Li cb(errorStatus, hidl_handle(nullptr), nullptr);
488*3e777be0SXin Li return Void();
489*3e777be0SXin Li }
490*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeFenced(...) after ExecuteGraph");
491*3e777be0SXin Li
492*3e777be0SXin Li V1_2::Timing timing = g_NoTiming;
493*3e777be0SXin Li V1_2::Timing fenceTiming = g_NoTiming;
494*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
495*3e777be0SXin Li {
496*3e777be0SXin Li fenceTiming.timeOnDevice = MicrosecondsDuration(ctx.deviceEnd, ctx.deviceStart);
497*3e777be0SXin Li fenceTiming.timeInDriver = MicrosecondsDuration(ctx.driverEnd, fenceExecutionStart);
498*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::fenceFinishExecutionTiming - Device = %lu Driver = %lu",
499*3e777be0SXin Li static_cast<unsigned long>(fenceTiming.timeOnDevice),
500*3e777be0SXin Li static_cast<unsigned long>(fenceTiming.timeInDriver));
501*3e777be0SXin Li }
502*3e777be0SXin Li
503*3e777be0SXin Li sp<ArmnnFencedExecutionCallback> armnnFencedExecutionCallback =
504*3e777be0SXin Li new ArmnnFencedExecutionCallback(V1_3::ErrorStatus::NONE, timing, fenceTiming);
505*3e777be0SXin Li cb(V1_3::ErrorStatus::NONE, hidl_handle(nullptr), armnnFencedExecutionCallback);
506*3e777be0SXin Li return Void();
507*3e777be0SXin Li }
508*3e777be0SXin Li
// Converts each request input into an armnn ConstTensor backed by the mapped
// memory pools, validating it against the network's input tensor info.
// Returns the validation error on mismatch, GENERAL_FAILURE if a tensor cannot
// be mapped to a pool, or NONE on success (with `inputs` fully populated).
template<typename HalVersion>
Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::PrepareMemoryForInputs(
    armnn::InputTensors& inputs,
    const V1_3::Request& request,
    const std::vector<android::nn::RunTimePoolInfo>& memPools)
{
    inputs.reserve(request.inputs.size());
    for (unsigned int i = 0; i < request.inputs.size(); i++)
    {
        const auto& inputArg = request.inputs[i];
        armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
        // inputs (of type InputTensors) is composed of a vector of ConstTensors.
        // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
        inputTensorInfo.SetConstant();
        auto result = ValidateRequestArgument<V1_3::ErrorStatus, V1_3::Request>(request,
                                                                               inputTensorInfo,
                                                                               inputArg,
                                                                               "input");

        if (result != V1_3::ErrorStatus::NONE)
        {
            return result;
        }

        const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);

        // A null memory area means the request argument could not be mapped to a pool.
        if (inputTensor.GetMemoryArea() == nullptr)
        {
            ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
            return V1_3::ErrorStatus::GENERAL_FAILURE;
        }

        inputs.emplace_back(i, inputTensor);
    }

    return V1_3::ErrorStatus::NONE;
}
546*3e777be0SXin Li
547*3e777be0SXin Li template<typename HalVersion>
PrepareMemoryForOutputs(armnn::OutputTensors & outputs,std::vector<V1_2::OutputShape> & outputShapes,const V1_3::Request & request,const std::vector<android::nn::RunTimePoolInfo> & memPools)548*3e777be0SXin Li Return<V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::PrepareMemoryForOutputs(
549*3e777be0SXin Li armnn::OutputTensors& outputs,
550*3e777be0SXin Li std::vector<V1_2::OutputShape> &outputShapes,
551*3e777be0SXin Li const V1_3::Request& request,
552*3e777be0SXin Li const std::vector<android::nn::RunTimePoolInfo>& memPools)
553*3e777be0SXin Li {
554*3e777be0SXin Li outputs.reserve(request.outputs.size());
555*3e777be0SXin Li for (unsigned int i = 0; i < request.outputs.size(); i++)
556*3e777be0SXin Li {
557*3e777be0SXin Li const auto& outputArg = request.outputs[i];
558*3e777be0SXin Li armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
559*3e777be0SXin Li auto result = ValidateRequestArgument<V1_3::ErrorStatus, V1_3::Request>(request,
560*3e777be0SXin Li outputTensorInfo,
561*3e777be0SXin Li outputArg,
562*3e777be0SXin Li "output");
563*3e777be0SXin Li
564*3e777be0SXin Li if (result != V1_3::ErrorStatus::NONE)
565*3e777be0SXin Li {
566*3e777be0SXin Li return result;
567*3e777be0SXin Li }
568*3e777be0SXin Li
569*3e777be0SXin Li const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
570*3e777be0SXin Li
571*3e777be0SXin Li if (outputTensor.GetMemoryArea() == nullptr)
572*3e777be0SXin Li {
573*3e777be0SXin Li ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
574*3e777be0SXin Li return V1_3::ErrorStatus::GENERAL_FAILURE;
575*3e777be0SXin Li }
576*3e777be0SXin Li const size_t outputSize = outputTensorInfo.GetNumBytes();
577*3e777be0SXin Li
578*3e777be0SXin Li unsigned int count = 0;
579*3e777be0SXin Li std::for_each(outputArg.dimensions.begin(), outputArg.dimensions.end(), [&](auto dim)
580*3e777be0SXin Li {
581*3e777be0SXin Li if (dim != 0)
582*3e777be0SXin Li {
583*3e777be0SXin Li outputTensorInfo.GetShape()[count] = dim;
584*3e777be0SXin Li }
585*3e777be0SXin Li else
586*3e777be0SXin Li {
587*3e777be0SXin Li outputTensorInfo.GetShape()[count] = outputArg.dimensions.size();
588*3e777be0SXin Li }
589*3e777be0SXin Li
590*3e777be0SXin Li count++;
591*3e777be0SXin Li });
592*3e777be0SXin Li
593*3e777be0SXin Li outputs.emplace_back(i, outputTensor);
594*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
595*3e777be0SXin Li
596*3e777be0SXin Li if (outputArg.location.length < outputSize)
597*3e777be0SXin Li {
598*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::Execute failed outputArg.location.length (%s) < outputSize (%s)",
599*3e777be0SXin Li std::to_string(outputArg.location.length).c_str(), std::to_string(outputSize).c_str());
600*3e777be0SXin Li outputShapes[i].isSufficient = false;
601*3e777be0SXin Li return V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
602*3e777be0SXin Li }
603*3e777be0SXin Li
604*3e777be0SXin Li size_t bufferSize = 0;
605*3e777be0SXin Li #if !defined(ARMNN_ANDROID_S)
606*3e777be0SXin Li bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
607*3e777be0SXin Li #else
608*3e777be0SXin Li bufferSize = memPools.at(outputArg.location.poolIndex).getSize();
609*3e777be0SXin Li #endif
610*3e777be0SXin Li if (bufferSize < outputSize)
611*3e777be0SXin Li {
612*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::Execute failed bufferSize (%s) < outputSize (%s)",
613*3e777be0SXin Li std::to_string(bufferSize).c_str(), std::to_string(outputSize).c_str());
614*3e777be0SXin Li outputShapes[i].isSufficient = false;
615*3e777be0SXin Li return V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
616*3e777be0SXin Li }
617*3e777be0SXin Li }
618*3e777be0SXin Li
619*3e777be0SXin Li return V1_3::ErrorStatus::NONE;
620*3e777be0SXin Li }
621*3e777be0SXin Li
622*3e777be0SXin Li template<typename HalVersion>
623*3e777be0SXin Li std::tuple<V1_3::ErrorStatus, hidl_vec<V1_2::OutputShape>, V1_2::Timing, std::string>
PrepareMemoryForIO(armnn::InputTensors & inputs,armnn::OutputTensors & outputs,std::vector<android::nn::RunTimePoolInfo> & memPools,const V1_3::Request & request)624*3e777be0SXin Li ArmnnPreparedModel_1_3<HalVersion>::PrepareMemoryForIO(armnn::InputTensors& inputs,
625*3e777be0SXin Li armnn::OutputTensors& outputs,
626*3e777be0SXin Li std::vector<android::nn::RunTimePoolInfo>& memPools,
627*3e777be0SXin Li const V1_3::Request& request)
628*3e777be0SXin Li {
629*3e777be0SXin Li #if !defined(ARMNN_ANDROID_S)
630*3e777be0SXin Li if (!setRunTimePoolInfosFromMemoryPools(&memPools, request.pools))
631*3e777be0SXin Li #else
632*3e777be0SXin Li if (!setRunTimePoolInfosFromMemoryPools(&memPools, uncheckedConvert(request.pools)))
633*3e777be0SXin Li #endif
634*3e777be0SXin Li {
635*3e777be0SXin Li return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
636*3e777be0SXin Li }
637*3e777be0SXin Li
638*3e777be0SXin Li // add the inputs and outputs with their data
639*3e777be0SXin Li try
640*3e777be0SXin Li {
641*3e777be0SXin Li if (PrepareMemoryForInputs(inputs, request, memPools) != V1_3::ErrorStatus::NONE)
642*3e777be0SXin Li {
643*3e777be0SXin Li return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
644*3e777be0SXin Li }
645*3e777be0SXin Li
646*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes(request.outputs.size());
647*3e777be0SXin Li
648*3e777be0SXin Li auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
649*3e777be0SXin Li if (errorStatus != V1_3::ErrorStatus::NONE)
650*3e777be0SXin Li {
651*3e777be0SXin Li return {errorStatus, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
652*3e777be0SXin Li }
653*3e777be0SXin Li }
654*3e777be0SXin Li catch (armnn::Exception& e)
655*3e777be0SXin Li {
656*3e777be0SXin Li ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
657*3e777be0SXin Li return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
658*3e777be0SXin Li }
659*3e777be0SXin Li catch (std::exception& e)
660*3e777be0SXin Li {
661*3e777be0SXin Li ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
662*3e777be0SXin Li return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
663*3e777be0SXin Li }
664*3e777be0SXin Li
665*3e777be0SXin Li return {V1_3::ErrorStatus::NONE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute"};
666*3e777be0SXin Li }
667*3e777be0SXin Li
668*3e777be0SXin Li template<typename HalVersion>
669*3e777be0SXin Li template<typename CallbackContext>
ExecuteSynchronously(const V1_3::Request & request,CallbackContext cbCtx)670*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_3<HalVersion>::ExecuteSynchronously(const V1_3::Request& request,
671*3e777be0SXin Li CallbackContext cbCtx)
672*3e777be0SXin Li {
673*3e777be0SXin Li if (cbCtx.ctx.measureTimings == V1_2::MeasureTiming::YES)
674*3e777be0SXin Li {
675*3e777be0SXin Li cbCtx.ctx.driverStart = Now();
676*3e777be0SXin Li }
677*3e777be0SXin Li
678*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(convertToV1_3(request), m_Model))
679*3e777be0SXin Li {
680*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::ExecuteSynchronously invalid request model");
681*3e777be0SXin Li cbCtx.callback(V1_3::ErrorStatus::INVALID_ARGUMENT,
682*3e777be0SXin Li {},
683*3e777be0SXin Li g_NoTiming,
684*3e777be0SXin Li "ArmnnPreparedModel_1_3::ExecuteSynchronously invalid request model");
685*3e777be0SXin Li return Void();
686*3e777be0SXin Li }
687*3e777be0SXin Li
688*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
689*3e777be0SXin Li {
690*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::ExecuteSynchronously invalid request model");
691*3e777be0SXin Li cbCtx.callback(V1_3::ErrorStatus::INVALID_ARGUMENT,
692*3e777be0SXin Li {},
693*3e777be0SXin Li g_NoTiming,
694*3e777be0SXin Li "ArmnnPreparedModel_1_3::ExecuteSynchronously invalid request model");
695*3e777be0SXin Li return Void();
696*3e777be0SXin Li }
697*3e777be0SXin Li
698*3e777be0SXin Li
699*3e777be0SXin Li // map the memory pool into shared pointers
700*3e777be0SXin Li // use a shared memory pools vector on the heap, as it is passed to the request thread
701*3e777be0SXin Li auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
702*3e777be0SXin Li
703*3e777be0SXin Li // allocate the tensors on the heap, as they are passed to the request thread
704*3e777be0SXin Li auto inputs = std::make_shared<armnn::InputTensors>();
705*3e777be0SXin Li auto outputs = std::make_shared<armnn::OutputTensors>();
706*3e777be0SXin Li
707*3e777be0SXin Li auto [status, outputShapes, timing, message] = PrepareMemoryForIO(*inputs, *outputs, *memPools, request);
708*3e777be0SXin Li if (status != V1_3::ErrorStatus::NONE)
709*3e777be0SXin Li {
710*3e777be0SXin Li cbCtx.callback(status, outputShapes, timing, message);
711*3e777be0SXin Li return Void();
712*3e777be0SXin Li }
713*3e777be0SXin Li
714*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::ExecuteSynchronously() before Execution");
715*3e777be0SXin Li
716*3e777be0SXin Li ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
717*3e777be0SXin Li return Void();
718*3e777be0SXin Li }
719*3e777be0SXin Li
720*3e777be0SXin Li template<typename HalVersion>
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measureTiming,executeSynchronously_cb cb)721*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_3<HalVersion>::executeSynchronously(const V1_0::Request& request,
722*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
723*3e777be0SXin Li executeSynchronously_cb cb)
724*3e777be0SXin Li {
725*3e777be0SXin Li if (!m_PreparedFromCache)
726*3e777be0SXin Li {
727*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
728*3e777be0SXin Li }
729*3e777be0SXin Li m_RequestCount++;
730*3e777be0SXin Li
731*3e777be0SXin Li if (cb == nullptr)
732*3e777be0SXin Li {
733*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::executeSynchronously invalid callback passed");
734*3e777be0SXin Li return Void();
735*3e777be0SXin Li }
736*3e777be0SXin Li
737*3e777be0SXin Li auto cbWrapper = [cb](V1_3::ErrorStatus errorStatus,
738*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
739*3e777be0SXin Li const V1_2::Timing& timing,
740*3e777be0SXin Li std::string)
741*3e777be0SXin Li {
742*3e777be0SXin Li cb(convertToV1_0(errorStatus), outputShapes, timing);
743*3e777be0SXin Li };
744*3e777be0SXin Li
745*3e777be0SXin Li CallbackContext_1_3 cbCtx;
746*3e777be0SXin Li cbCtx.callback = cbWrapper;
747*3e777be0SXin Li cbCtx.ctx.measureTimings = measureTiming;
748*3e777be0SXin Li
749*3e777be0SXin Li ExecuteSynchronously(convertToV1_3(request), cbCtx);
750*3e777be0SXin Li return Void();
751*3e777be0SXin Li }
752*3e777be0SXin Li
753*3e777be0SXin Li template<typename HalVersion>
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measureTiming,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)754*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_3<HalVersion>::executeSynchronously_1_3(
755*3e777be0SXin Li const V1_3::Request& request,
756*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
757*3e777be0SXin Li const V1_3::OptionalTimePoint& deadline,
758*3e777be0SXin Li const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
759*3e777be0SXin Li executeSynchronously_1_3_cb cb)
760*3e777be0SXin Li {
761*3e777be0SXin Li if (!m_PreparedFromCache)
762*3e777be0SXin Li {
763*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::executeSynchronously_1_3(): %s", GetModelSummary(m_Model).c_str());
764*3e777be0SXin Li }
765*3e777be0SXin Li m_RequestCount++;
766*3e777be0SXin Li
767*3e777be0SXin Li if (cb == nullptr)
768*3e777be0SXin Li {
769*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_3::executeSynchronously_1_3 invalid callback passed");
770*3e777be0SXin Li return Void();
771*3e777be0SXin Li }
772*3e777be0SXin Li
773*3e777be0SXin Li if (deadline.getDiscriminator() != V1_3::OptionalTimePoint::hidl_discriminator::none)
774*3e777be0SXin Li {
775*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::executeSynchronously_1_3 parameter deadline is set but not supported.");
776*3e777be0SXin Li }
777*3e777be0SXin Li
778*3e777be0SXin Li if (loopTimeoutDuration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none)
779*3e777be0SXin Li {
780*3e777be0SXin Li ALOGW(
781*3e777be0SXin Li "ArmnnPreparedModel_1_3::executeSynchronously_1_3 parameter loopTimeoutDuration is set but not supported.");
782*3e777be0SXin Li }
783*3e777be0SXin Li
784*3e777be0SXin Li auto cbWrapper = [cb](V1_3::ErrorStatus errorStatus,
785*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
786*3e777be0SXin Li const V1_2::Timing& timing,
787*3e777be0SXin Li std::string)
788*3e777be0SXin Li {
789*3e777be0SXin Li cb(errorStatus, outputShapes, timing);
790*3e777be0SXin Li };
791*3e777be0SXin Li
792*3e777be0SXin Li CallbackContext_1_3 cbCtx;
793*3e777be0SXin Li cbCtx.callback = cbWrapper;
794*3e777be0SXin Li cbCtx.ctx.measureTimings = measureTiming;
795*3e777be0SXin Li
796*3e777be0SXin Li ExecuteSynchronously(request, cbCtx);
797*3e777be0SXin Li return Void();
798*3e777be0SXin Li }
799*3e777be0SXin Li
800*3e777be0SXin Li template<typename HalVersion>
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,V1_3::IPreparedModel::configureExecutionBurst_cb cb)801*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_3<HalVersion>::configureExecutionBurst(
802*3e777be0SXin Li const sp<V1_2::IBurstCallback>& callback,
803*3e777be0SXin Li const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
804*3e777be0SXin Li const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
805*3e777be0SXin Li V1_3::IPreparedModel::configureExecutionBurst_cb cb)
806*3e777be0SXin Li {
807*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::configureExecutionBurst");
808*3e777be0SXin Li const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
809*3e777be0SXin Li requestChannel,
810*3e777be0SXin Li resultChannel,
811*3e777be0SXin Li this);
812*3e777be0SXin Li
813*3e777be0SXin Li if (burst == nullptr)
814*3e777be0SXin Li {
815*3e777be0SXin Li cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
816*3e777be0SXin Li }
817*3e777be0SXin Li else
818*3e777be0SXin Li {
819*3e777be0SXin Li cb(V1_0::ErrorStatus::NONE, burst);
820*3e777be0SXin Li }
821*3e777be0SXin Li return Void();
822*3e777be0SXin Li }
823*3e777be0SXin Li
824*3e777be0SXin Li template<typename HalVersion>
825*3e777be0SXin Li template<typename CallbackContext>
ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,armnn::InputTensors & inputTensors,armnn::OutputTensors & outputTensors,CallbackContext cb)826*3e777be0SXin Li Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::ExecuteGraph(
827*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
828*3e777be0SXin Li armnn::InputTensors& inputTensors,
829*3e777be0SXin Li armnn::OutputTensors& outputTensors,
830*3e777be0SXin Li CallbackContext cb)
831*3e777be0SXin Li {
832*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::ExecuteGraph(...)");
833*3e777be0SXin Li // Capture the graph execution start time.
834*3e777be0SXin Li std::chrono::time_point<std::chrono::system_clock> graphExecutionStart = std::chrono::system_clock::now();
835*3e777be0SXin Li
836*3e777be0SXin Li DumpTensorsIfRequired("Input", inputTensors);
837*3e777be0SXin Li
838*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes(outputTensors.size());
839*3e777be0SXin Li for (unsigned int i = 0; i < outputTensors.size(); i++)
840*3e777be0SXin Li {
841*3e777be0SXin Li std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
842*3e777be0SXin Li const armnn::Tensor outputTensor = outputTensorPair.second;
843*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
844*3e777be0SXin Li
845*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
846*3e777be0SXin Li }
847*3e777be0SXin Li
848*3e777be0SXin Li // run it
849*3e777be0SXin Li try
850*3e777be0SXin Li {
851*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
852*3e777be0SXin Li {
853*3e777be0SXin Li cb.ctx.deviceStart = Now();
854*3e777be0SXin Li }
855*3e777be0SXin Li armnn::Status status;
856*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
857*3e777be0SXin Li {
858*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled true");
859*3e777be0SXin Li status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
860*3e777be0SXin Li }
861*3e777be0SXin Li else
862*3e777be0SXin Li {
863*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph m_AsyncModelExecutionEnabled false");
864*3e777be0SXin Li // Create a vector of Input and Output Ids which can be imported. An empty vector means all will be copied.
865*3e777be0SXin Li std::vector<armnn::ImportedInputId> importedInputIds;
866*3e777be0SXin Li if (m_EnableImport)
867*3e777be0SXin Li {
868*3e777be0SXin Li importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
869*3e777be0SXin Li }
870*3e777be0SXin Li std::vector<armnn::ImportedOutputId> importedOutputIds;
871*3e777be0SXin Li if (m_EnableExport)
872*3e777be0SXin Li {
873*3e777be0SXin Li importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
874*3e777be0SXin Li }
875*3e777be0SXin Li status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors,
876*3e777be0SXin Li importedInputIds, importedOutputIds);
877*3e777be0SXin Li }
878*3e777be0SXin Li
879*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
880*3e777be0SXin Li {
881*3e777be0SXin Li cb.ctx.deviceEnd = Now();
882*3e777be0SXin Li }
883*3e777be0SXin Li if (status != armnn::Status::Success)
884*3e777be0SXin Li {
885*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_3::ExecuteGraph EnqueueWorkload failed");
886*3e777be0SXin Li cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
887*3e777be0SXin Li return V1_3::ErrorStatus::GENERAL_FAILURE;
888*3e777be0SXin Li }
889*3e777be0SXin Li }
890*3e777be0SXin Li catch (armnn::Exception& e)
891*3e777be0SXin Li {
892*3e777be0SXin Li ALOGW("armnn:Exception caught from EnqueueWorkload: %s", e.what());
893*3e777be0SXin Li cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
894*3e777be0SXin Li return V1_3::ErrorStatus::GENERAL_FAILURE;
895*3e777be0SXin Li }
896*3e777be0SXin Li catch (std::exception& e)
897*3e777be0SXin Li {
898*3e777be0SXin Li ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
899*3e777be0SXin Li cb.callback(V1_3::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
900*3e777be0SXin Li return V1_3::ErrorStatus::GENERAL_FAILURE;
901*3e777be0SXin Li }
902*3e777be0SXin Li
903*3e777be0SXin Li CommitPools(*pMemPools);
904*3e777be0SXin Li
905*3e777be0SXin Li DumpTensorsIfRequired("Output", outputTensors);
906*3e777be0SXin Li
907*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
908*3e777be0SXin Li {
909*3e777be0SXin Li cb.ctx.driverEnd = Now();
910*3e777be0SXin Li V1_2::Timing timing;
911*3e777be0SXin Li timing.timeOnDevice = MicrosecondsDuration(cb.ctx.deviceEnd, cb.ctx.deviceStart);
912*3e777be0SXin Li timing.timeInDriver = MicrosecondsDuration(cb.ctx.driverEnd, cb.ctx.driverStart);
913*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu",
914*3e777be0SXin Li static_cast<unsigned long>(timing.timeOnDevice), static_cast<unsigned long>(timing.timeInDriver));
915*3e777be0SXin Li cb.callback(V1_3::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph");
916*3e777be0SXin Li } else
917*3e777be0SXin Li {
918*3e777be0SXin Li cb.callback(V1_3::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
919*3e777be0SXin Li }
920*3e777be0SXin Li // Log the total time in this call. This is a good number to compare to that printed out by
921*3e777be0SXin Li // RuntimeImpl::EnqueueWorkload. The difference should be the execution overhead of the driver.
922*3e777be0SXin Li ALOGI("ArmnnPreparedModel_1_3::ExecuteGraph Execution time = %lld µs",
923*3e777be0SXin Li std::chrono::duration_cast<std::chrono::microseconds>
924*3e777be0SXin Li (std::chrono::system_clock::now() - graphExecutionStart).count());
925*3e777be0SXin Li return V1_3::ErrorStatus::NONE;
926*3e777be0SXin Li }
927*3e777be0SXin Li
928*3e777be0SXin Li /// Schedule the graph prepared from the request for execution
929*3e777be0SXin Li template<typename HalVersion>
930*3e777be0SXin Li template<typename CallbackContext>
ScheduleGraphForExecution(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext,armnn::QosExecPriority priority)931*3e777be0SXin Li void ArmnnPreparedModel_1_3<HalVersion>::ScheduleGraphForExecution(
932*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
933*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors,
934*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors,
935*3e777be0SXin Li CallbackContext callbackContext,
936*3e777be0SXin Li armnn::QosExecPriority priority)
937*3e777be0SXin Li {
938*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution(...)");
939*3e777be0SXin Li
940*3e777be0SXin Li DumpTensorsIfRequired("Input", *inputTensors);
941*3e777be0SXin Li
942*3e777be0SXin Li unsigned int outputTensorSize = outputTensors.get()->size();
943*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
944*3e777be0SXin Li for (unsigned int i = 0; i < outputTensorSize; i++)
945*3e777be0SXin Li {
946*3e777be0SXin Li std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
947*3e777be0SXin Li const armnn::Tensor outputTensor = outputTensorPair.second;
948*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
949*3e777be0SXin Li
950*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
951*3e777be0SXin Li }
952*3e777be0SXin Li
953*3e777be0SXin Li auto tpCb = std::make_shared<
954*3e777be0SXin Li ArmnnThreadPoolCallback_1_3<CallbackContext_1_3>>(this,
955*3e777be0SXin Li pMemPools,
956*3e777be0SXin Li outputShapes,
957*3e777be0SXin Li inputTensors,
958*3e777be0SXin Li outputTensors,
959*3e777be0SXin Li callbackContext);
960*3e777be0SXin Li
961*3e777be0SXin Li m_Threadpool->Schedule(m_NetworkId,
962*3e777be0SXin Li *tpCb->m_InputTensors,
963*3e777be0SXin Li *tpCb->m_OutputTensors,
964*3e777be0SXin Li priority,
965*3e777be0SXin Li tpCb);
966*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::ScheduleGraphForExecution end");
967*3e777be0SXin Li }
968*3e777be0SXin Li
969*3e777be0SXin Li template<typename HalVersion>
ExecuteWithDummyInputs(unsigned int numInputs,unsigned int numOutputs)970*3e777be0SXin Li bool ArmnnPreparedModel_1_3<HalVersion>::ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs)
971*3e777be0SXin Li {
972*3e777be0SXin Li std::vector<std::vector<char>> storage;
973*3e777be0SXin Li armnn::InputTensors inputTensors;
974*3e777be0SXin Li for (unsigned int i = 0; i < numInputs; i++)
975*3e777be0SXin Li {
976*3e777be0SXin Li armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
977*3e777be0SXin Li // pInputTensors (of type InputTensors) is composed of a vector of ConstTensors.
978*3e777be0SXin Li // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
979*3e777be0SXin Li inputTensorInfo.SetConstant();
980*3e777be0SXin Li
981*3e777be0SXin Li storage.emplace_back(inputTensorInfo.GetNumBytes());
982*3e777be0SXin Li const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
983*3e777be0SXin Li
984*3e777be0SXin Li inputTensors.emplace_back(i, inputTensor);
985*3e777be0SXin Li }
986*3e777be0SXin Li
987*3e777be0SXin Li armnn::OutputTensors outputTensors;
988*3e777be0SXin Li for (unsigned int i = 0; i < numOutputs; i++)
989*3e777be0SXin Li {
990*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
991*3e777be0SXin Li storage.emplace_back(outputTensorInfo.GetNumBytes());
992*3e777be0SXin Li const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
993*3e777be0SXin Li
994*3e777be0SXin Li outputTensors.emplace_back(i, outputTensor);
995*3e777be0SXin Li }
996*3e777be0SXin Li
997*3e777be0SXin Li auto nullCallback = [](V1_3::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
998*3e777be0SXin Li CallbackContext_1_3 callbackContext;
999*3e777be0SXin Li callbackContext.callback = nullCallback;
1000*3e777be0SXin Li callbackContext.ctx.measureTimings = V1_2::MeasureTiming::NO;
1001*3e777be0SXin Li auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
1002*3e777be0SXin Li
1003*3e777be0SXin Li auto errorStatus = ExecuteGraph(memPools,
1004*3e777be0SXin Li inputTensors,
1005*3e777be0SXin Li outputTensors,
1006*3e777be0SXin Li callbackContext);
1007*3e777be0SXin Li return errorStatus == V1_3::ErrorStatus::NONE;
1008*3e777be0SXin Li }
1009*3e777be0SXin Li
1010*3e777be0SXin Li template<typename HalVersion>
Execute(const V1_3::Request & request,V1_2::MeasureTiming measureTiming,CallbackAsync_1_3 callback)1011*3e777be0SXin Li Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<HalVersion>::Execute(const V1_3::Request& request,
1012*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
1013*3e777be0SXin Li CallbackAsync_1_3 callback)
1014*3e777be0SXin Li {
1015*3e777be0SXin Li ExecutionContext_1_3 ctx;
1016*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
1017*3e777be0SXin Li {
1018*3e777be0SXin Li ctx.measureTimings = measureTiming;
1019*3e777be0SXin Li ctx.driverStart = Now();
1020*3e777be0SXin Li }
1021*3e777be0SXin Li
1022*3e777be0SXin Li if (!m_PreparedFromCache)
1023*3e777be0SXin Li {
1024*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute(): %s", GetModelSummary(m_Model).c_str());
1025*3e777be0SXin Li }
1026*3e777be0SXin Li m_RequestCount++;
1027*3e777be0SXin Li
1028*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
1029*3e777be0SXin Li {
1030*3e777be0SXin Li callback(V1_3::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_3::execute");
1031*3e777be0SXin Li return V1_3::ErrorStatus::INVALID_ARGUMENT;
1032*3e777be0SXin Li }
1033*3e777be0SXin Li
1034*3e777be0SXin Li if (!m_RequestInputsAndOutputsDumpDir.empty())
1035*3e777be0SXin Li {
1036*3e777be0SXin Li ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
1037*3e777be0SXin Li }
1038*3e777be0SXin Li
1039*3e777be0SXin Li // map the memory pool into shared pointers
1040*3e777be0SXin Li // use a shared memory pools vector on the heap, as it is passed to the request thread
1041*3e777be0SXin Li auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
1042*3e777be0SXin Li
1043*3e777be0SXin Li // allocate the tensors on the heap, as they are passed to the request thread
1044*3e777be0SXin Li auto inputTensors = std::make_shared<armnn::InputTensors>();
1045*3e777be0SXin Li auto outputTensors = std::make_shared<armnn::OutputTensors>();
1046*3e777be0SXin Li
1047*3e777be0SXin Li auto [status, outShapes, timing, message] = PrepareMemoryForIO(*inputTensors, *outputTensors,
1048*3e777be0SXin Li *memPools, request);
1049*3e777be0SXin Li if (status != V1_3::ErrorStatus::NONE)
1050*3e777be0SXin Li {
1051*3e777be0SXin Li callback(status, outShapes, timing, message);
1052*3e777be0SXin Li }
1053*3e777be0SXin Li
1054*3e777be0SXin Li switch(status)
1055*3e777be0SXin Li {
1056*3e777be0SXin Li case V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
1057*3e777be0SXin Li return V1_3::ErrorStatus::NONE;
1058*3e777be0SXin Li case V1_3::ErrorStatus::GENERAL_FAILURE:
1059*3e777be0SXin Li return V1_3::ErrorStatus::GENERAL_FAILURE;
1060*3e777be0SXin Li case V1_3::ErrorStatus::INVALID_ARGUMENT:
1061*3e777be0SXin Li return V1_3::ErrorStatus::INVALID_ARGUMENT;
1062*3e777be0SXin Li default:
1063*3e777be0SXin Li {}
1064*3e777be0SXin Li }
1065*3e777be0SXin Li CallbackContext_1_3 cb;
1066*3e777be0SXin Li cb.callback = callback;
1067*3e777be0SXin Li cb.ctx = ctx;
1068*3e777be0SXin Li
1069*3e777be0SXin Li
1070*3e777be0SXin Li enum class QosExecPriority
1071*3e777be0SXin Li {
1072*3e777be0SXin Li Low = 0,
1073*3e777be0SXin Li Medium = 1,
1074*3e777be0SXin Li High = 2
1075*3e777be0SXin Li };
1076*3e777be0SXin Li
1077*3e777be0SXin Li
1078*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
1079*3e777be0SXin Li {
1080*3e777be0SXin Li armnn::QosExecPriority priority;
1081*3e777be0SXin Li
1082*3e777be0SXin Li switch (GetModelPriority()) {
1083*3e777be0SXin Li case V1_3::Priority::LOW:
1084*3e777be0SXin Li priority = armnn::QosExecPriority::Low;
1085*3e777be0SXin Li break;
1086*3e777be0SXin Li case V1_3::Priority::MEDIUM:
1087*3e777be0SXin Li priority = armnn::QosExecPriority::Medium;
1088*3e777be0SXin Li break;
1089*3e777be0SXin Li case V1_3::Priority::HIGH:
1090*3e777be0SXin Li priority = armnn::QosExecPriority::High;
1091*3e777be0SXin Li break;
1092*3e777be0SXin Li default:
1093*3e777be0SXin Li priority = armnn::QosExecPriority::Medium;
1094*3e777be0SXin Li
1095*3e777be0SXin Li }
1096*3e777be0SXin Li
1097*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute(...) before ScheduleGraphForExecution");
1098*3e777be0SXin Li ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb, priority);
1099*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute(...) after ScheduleGraphForExecution");
1100*3e777be0SXin Li return V1_3::ErrorStatus::NONE;
1101*3e777be0SXin Li }
1102*3e777be0SXin Li
1103*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute(...) before PostMsg");
1104*3e777be0SXin Li // post the request for asynchronous execution
1105*3e777be0SXin Li m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
1106*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute(...) after PostMsg");
1107*3e777be0SXin Li return V1_3::ErrorStatus::NONE;
1108*3e777be0SXin Li }
1109*3e777be0SXin Li
1110*3e777be0SXin Li template<typename HalVersion>
GetModelPriority()1111*3e777be0SXin Li V1_3::Priority ArmnnPreparedModel_1_3<HalVersion>::GetModelPriority()
1112*3e777be0SXin Li {
1113*3e777be0SXin Li return m_ModelPriority;
1114*3e777be0SXin Li }
1115*3e777be0SXin Li
1116*3e777be0SXin Li template<typename HalVersion>
1117*3e777be0SXin Li template <typename CallbackContext>
Notify(armnn::Status status,armnn::InferenceTimingPair timeTaken)1118*3e777be0SXin Li void ArmnnPreparedModel_1_3<HalVersion>::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify(
1119*3e777be0SXin Li armnn::Status status, armnn::InferenceTimingPair timeTaken)
1120*3e777be0SXin Li {
1121*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3<CallbackContext>::Notify");
1122*3e777be0SXin Li CommitPools(*m_MemPools);
1123*3e777be0SXin Li
1124*3e777be0SXin Li m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
1125*3e777be0SXin Li
1126*3e777be0SXin Li if (status != armnn::Status::Success)
1127*3e777be0SXin Li {
1128*3e777be0SXin Li ALOGW("ArmnnThreadPoolCallback_1_3::Notify EnqueueWorkload failed");
1129*3e777be0SXin Li m_CallbackContext.callback(V1_3::ErrorStatus::GENERAL_FAILURE,
1130*3e777be0SXin Li {},
1131*3e777be0SXin Li g_NoTiming,
1132*3e777be0SXin Li "ArmnnPreparedModel_1_3::ArmnnThreadPoolCallback_1_3");
1133*3e777be0SXin Li return;
1134*3e777be0SXin Li }
1135*3e777be0SXin Li
1136*3e777be0SXin Li if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
1137*3e777be0SXin Li {
1138*3e777be0SXin Li m_CallbackContext.ctx.deviceStart = timeTaken.first;
1139*3e777be0SXin Li m_CallbackContext.ctx.deviceEnd = timeTaken.second;
1140*3e777be0SXin Li m_CallbackContext.ctx.driverEnd = std::chrono::steady_clock::now();
1141*3e777be0SXin Li V1_2::Timing timing;
1142*3e777be0SXin Li timing.timeOnDevice = MicrosecondsDuration(m_CallbackContext.ctx.deviceEnd, m_CallbackContext.ctx.deviceStart);
1143*3e777be0SXin Li timing.timeInDriver = MicrosecondsDuration(m_CallbackContext.ctx.driverEnd, m_CallbackContext.ctx.driverStart);
1144*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_3::execute timing - Device = %lu Driver = %lu",
1145*3e777be0SXin Li static_cast<unsigned long>(timing.timeOnDevice), static_cast<unsigned long>(timing.timeInDriver));
1146*3e777be0SXin Li m_CallbackContext.callback(
1147*3e777be0SXin Li V1_3::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_3::ExecuteGraph");
1148*3e777be0SXin Li } else
1149*3e777be0SXin Li {
1150*3e777be0SXin Li m_CallbackContext.callback(
1151*3e777be0SXin Li V1_3::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_3::ExecuteGraph");
1152*3e777be0SXin Li }
1153*3e777be0SXin Li return;
1154*3e777be0SXin Li }
1155*3e777be0SXin Li
1156*3e777be0SXin Li #ifdef ARMNN_ANDROID_NN_V1_3
1157*3e777be0SXin Li template class ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>;
1158*3e777be0SXin Li template Return <V1_3::ErrorStatus> ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ExecuteGraph<CallbackContext_1_3>(
1159*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
1160*3e777be0SXin Li armnn::InputTensors& pInputTensors,
1161*3e777be0SXin Li armnn::OutputTensors& pOutputTensors,
1162*3e777be0SXin Li CallbackContext_1_3 cb);
1163*3e777be0SXin Li
1164*3e777be0SXin Li template void ArmnnPreparedModel_1_3<hal_1_3::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_3>(
1165*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
1166*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors,
1167*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors,
1168*3e777be0SXin Li CallbackContext_1_3 callbackContext,
1169*3e777be0SXin Li armnn::QosExecPriority priority);
1170*3e777be0SXin Li #endif
1171*3e777be0SXin Li
1172*3e777be0SXin Li } // namespace armnn_driver
1173