// xref: /aosp_15_r20/external/android-nn-driver/ArmnnPreparedModel_1_3.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread_1_3.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
20*3e777be0SXin Li 
21*3e777be0SXin Li namespace armnn_driver
22*3e777be0SXin Li {
23*3e777be0SXin Li using CallbackAsync_1_3 = std::function<
24*3e777be0SXin Li                                 void(V1_3::ErrorStatus errorStatus,
25*3e777be0SXin Li                                 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
26*3e777be0SXin Li                                 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
27*3e777be0SXin Li                                 std::string callingFunction)>;
28*3e777be0SXin Li 
29*3e777be0SXin Li struct ExecutionContext_1_3
30*3e777be0SXin Li {
31*3e777be0SXin Li     ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings =
32*3e777be0SXin Li         ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
33*3e777be0SXin Li     TimePoint driverStart;
34*3e777be0SXin Li     TimePoint driverEnd;
35*3e777be0SXin Li     TimePoint deviceStart;
36*3e777be0SXin Li     TimePoint deviceEnd;
37*3e777be0SXin Li };
38*3e777be0SXin Li 
39*3e777be0SXin Li using CallbackContext_1_3 = CallbackContext<CallbackAsync_1_3, ExecutionContext_1_3>;
40*3e777be0SXin Li 
41*3e777be0SXin Li using executeFenced_cb = std::function<void(::android::hardware::neuralnetworks::V1_3::ErrorStatus status,
42*3e777be0SXin Li     const ::android::hardware::hidl_handle& syncFence,
43*3e777be0SXin Li     const ::android::sp<::android::hardware::neuralnetworks::V1_3::IFencedExecutionCallback>& callback)>;
44*3e777be0SXin Li 
45*3e777be0SXin Li template <typename HalVersion>
46*3e777be0SXin Li class ArmnnPreparedModel_1_3 : public V1_3::IPreparedModel
47*3e777be0SXin Li {
48*3e777be0SXin Li public:
49*3e777be0SXin Li     using HalModel = typename V1_3::Model;
50*3e777be0SXin Li 
51*3e777be0SXin Li     ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
52*3e777be0SXin Li                            armnn::IRuntime* runtime,
53*3e777be0SXin Li                            const HalModel& model,
54*3e777be0SXin Li                            const std::string& requestInputsAndOutputsDumpDir,
55*3e777be0SXin Li                            const bool gpuProfilingEnabled,
56*3e777be0SXin Li                            V1_3::Priority priority = V1_3::Priority::MEDIUM,
57*3e777be0SXin Li                            const bool asyncModelExecutionEnabled = false,
58*3e777be0SXin Li                            const unsigned int numberOfThreads = 1,
59*3e777be0SXin Li                            const bool importEnabled = false,
60*3e777be0SXin Li                            const bool exportEnabled = false);
61*3e777be0SXin Li 
62*3e777be0SXin Li     ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
63*3e777be0SXin Li                            armnn::IRuntime* runtime,
64*3e777be0SXin Li                            const std::string& requestInputsAndOutputsDumpDir,
65*3e777be0SXin Li                            const bool gpuProfilingEnabled,
66*3e777be0SXin Li                            V1_3::Priority priority = V1_3::Priority::MEDIUM,
67*3e777be0SXin Li                            const bool asyncModelExecutionEnabled = false,
68*3e777be0SXin Li                            const unsigned int numberOfThreads = 1,
69*3e777be0SXin Li                            const bool importEnabled = false,
70*3e777be0SXin Li                            const bool exportEnabled = false,
71*3e777be0SXin Li                            const bool preparedFromCache = false);
72*3e777be0SXin Li 
73*3e777be0SXin Li     virtual ~ArmnnPreparedModel_1_3();
74*3e777be0SXin Li 
75*3e777be0SXin Li     Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
76*3e777be0SXin Li                                       const ::android::sp<V1_0::IExecutionCallback>& callback) override;
77*3e777be0SXin Li 
78*3e777be0SXin Li     Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
79*3e777be0SXin Li                                           const ::android::sp<V1_2::IExecutionCallback>& callback) override;
80*3e777be0SXin Li 
81*3e777be0SXin Li     Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request& request,
82*3e777be0SXin Li                                           V1_2::MeasureTiming measure,
83*3e777be0SXin Li                                           const V1_3::OptionalTimePoint&,
84*3e777be0SXin Li                                           const V1_3::OptionalTimeoutDuration&,
85*3e777be0SXin Li                                           const ::android::sp<V1_3::IExecutionCallback>& callback) override;
86*3e777be0SXin Li 
87*3e777be0SXin Li     Return<void> executeSynchronously(const V1_0::Request &request,
88*3e777be0SXin Li                                       V1_2::MeasureTiming measure,
89*3e777be0SXin Li                                       V1_3::IPreparedModel::executeSynchronously_cb cb) override;
90*3e777be0SXin Li 
91*3e777be0SXin Li     Return<void> executeSynchronously_1_3(const V1_3::Request &request,
92*3e777be0SXin Li                                           V1_2::MeasureTiming measure,
93*3e777be0SXin Li                                           const V1_3::OptionalTimePoint& deadline,
94*3e777be0SXin Li                                           const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
95*3e777be0SXin Li                                           V1_3::IPreparedModel::executeSynchronously_1_3_cb cb) override;
96*3e777be0SXin Li 
97*3e777be0SXin Li     Return<void> executeFenced(const V1_3::Request& request,
98*3e777be0SXin Li                                const android::hardware::hidl_vec<android::hardware::hidl_handle>& fenceWaitFor,
99*3e777be0SXin Li                                V1_2::MeasureTiming measure,
100*3e777be0SXin Li                                const V1_3::OptionalTimePoint& deadline,
101*3e777be0SXin Li                                const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
102*3e777be0SXin Li                                const V1_3::OptionalTimeoutDuration& duration,
103*3e777be0SXin Li                                executeFenced_cb callback) override;
104*3e777be0SXin Li 
105*3e777be0SXin Li     Return<void> configureExecutionBurst(
106*3e777be0SXin Li             const ::android::sp<V1_2::IBurstCallback>& callback,
107*3e777be0SXin Li             const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
108*3e777be0SXin Li             const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
109*3e777be0SXin Li             configureExecutionBurst_cb cb) override;
110*3e777be0SXin Li 
111*3e777be0SXin Li     template<typename CallbackContext>
112*3e777be0SXin Li     Return<void> ExecuteSynchronously(const V1_3::Request& request, CallbackContext cbCtx);
113*3e777be0SXin Li 
114*3e777be0SXin Li     /// execute the graph prepared from the request
115*3e777be0SXin Li     template<typename CallbackContext>
116*3e777be0SXin Li     Return <V1_3::ErrorStatus> ExecuteGraph(
117*3e777be0SXin Li               std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
118*3e777be0SXin Li               armnn::InputTensors& inputTensors,
119*3e777be0SXin Li               armnn::OutputTensors& outputTensors,
120*3e777be0SXin Li               CallbackContext callback);
121*3e777be0SXin Li 
122*3e777be0SXin Li     /// Executes this model with dummy inputs (e.g. all zeroes).
123*3e777be0SXin Li     /// \return false on failure, otherwise true
124*3e777be0SXin Li     bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs);
125*3e777be0SXin Li 
126*3e777be0SXin Li     V1_3::Priority GetModelPriority();
127*3e777be0SXin Li 
128*3e777be0SXin Li private:
129*3e777be0SXin Li 
130*3e777be0SXin Li     template<typename CallbackContext>
131*3e777be0SXin Li     class ArmnnThreadPoolCallback_1_3 : public armnn::IAsyncExecutionCallback
132*3e777be0SXin Li     {
133*3e777be0SXin Li     public:
ArmnnThreadPoolCallback_1_3(ArmnnPreparedModel_1_3<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::vector<V1_2::OutputShape> outputShapes,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)134*3e777be0SXin Li         ArmnnThreadPoolCallback_1_3(ArmnnPreparedModel_1_3<HalVersion>* model,
135*3e777be0SXin Li                                     std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
136*3e777be0SXin Li                                     std::vector<V1_2::OutputShape> outputShapes,
137*3e777be0SXin Li                                     std::shared_ptr<armnn::InputTensors>& inputTensors,
138*3e777be0SXin Li                                     std::shared_ptr<armnn::OutputTensors>& outputTensors,
139*3e777be0SXin Li                                     CallbackContext callbackContext) :
140*3e777be0SXin Li                 m_Model(model),
141*3e777be0SXin Li                 m_MemPools(pMemPools),
142*3e777be0SXin Li                 m_OutputShapes(outputShapes),
143*3e777be0SXin Li                 m_InputTensors(inputTensors),
144*3e777be0SXin Li                 m_OutputTensors(outputTensors),
145*3e777be0SXin Li                 m_CallbackContext(callbackContext)
146*3e777be0SXin Li         {}
147*3e777be0SXin Li 
148*3e777be0SXin Li         void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
149*3e777be0SXin Li 
150*3e777be0SXin Li         ArmnnPreparedModel_1_3<HalVersion>* m_Model;
151*3e777be0SXin Li         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
152*3e777be0SXin Li         std::vector<V1_2::OutputShape> m_OutputShapes;
153*3e777be0SXin Li         std::shared_ptr<armnn::InputTensors> m_InputTensors;
154*3e777be0SXin Li         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
155*3e777be0SXin Li         CallbackContext m_CallbackContext;
156*3e777be0SXin Li     };
157*3e777be0SXin Li 
158*3e777be0SXin Li     Return <V1_3::ErrorStatus> Execute(const V1_3::Request& request,
159*3e777be0SXin Li                                        V1_2::MeasureTiming measureTiming,
160*3e777be0SXin Li                                        CallbackAsync_1_3 callback);
161*3e777be0SXin Li 
162*3e777be0SXin Li     Return<V1_3::ErrorStatus> PrepareMemoryForInputs(
163*3e777be0SXin Li         armnn::InputTensors& inputs,
164*3e777be0SXin Li         const V1_3::Request& request,
165*3e777be0SXin Li         const std::vector<android::nn::RunTimePoolInfo>& memPools);
166*3e777be0SXin Li 
167*3e777be0SXin Li     Return<V1_3::ErrorStatus> PrepareMemoryForOutputs(
168*3e777be0SXin Li         armnn::OutputTensors& outputs,
169*3e777be0SXin Li         std::vector<V1_2::OutputShape> &outputShapes,
170*3e777be0SXin Li         const V1_3::Request& request,
171*3e777be0SXin Li         const std::vector<android::nn::RunTimePoolInfo>& memPools);
172*3e777be0SXin Li 
173*3e777be0SXin Li     std::tuple<V1_3::ErrorStatus, hidl_vec<V1_2::OutputShape>, V1_2::Timing, std::string> PrepareMemoryForIO(
174*3e777be0SXin Li         armnn::InputTensors& inputs,
175*3e777be0SXin Li         armnn::OutputTensors& outputs,
176*3e777be0SXin Li         std::vector<android::nn::RunTimePoolInfo>& memPools,
177*3e777be0SXin Li         const V1_3::Request& request);
178*3e777be0SXin Li 
179*3e777be0SXin Li     template <typename TensorBindingCollection>
180*3e777be0SXin Li     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
181*3e777be0SXin Li 
182*3e777be0SXin Li     /// schedule the graph prepared from the request for execution
183*3e777be0SXin Li     template<typename CallbackContext>
184*3e777be0SXin Li     void ScheduleGraphForExecution(
185*3e777be0SXin Li             std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
186*3e777be0SXin Li             std::shared_ptr<armnn::InputTensors>& inputTensors,
187*3e777be0SXin Li             std::shared_ptr<armnn::OutputTensors>& outputTensors,
188*3e777be0SXin Li             CallbackContext m_CallbackContext,
189*3e777be0SXin Li             armnn::QosExecPriority priority);
190*3e777be0SXin Li 
191*3e777be0SXin Li     armnn::NetworkId                               m_NetworkId;
192*3e777be0SXin Li     armnn::IRuntime*                               m_Runtime;
193*3e777be0SXin Li     V1_3::Model                                    m_Model;
194*3e777be0SXin Li     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
195*3e777be0SXin Li     // It is specific to this class, so it is declared as static here
196*3e777be0SXin Li     static RequestThread_1_3<ArmnnPreparedModel_1_3,
197*3e777be0SXin Li                              HalVersion,
198*3e777be0SXin Li                              CallbackContext_1_3>  m_RequestThread;
199*3e777be0SXin Li     uint32_t                                       m_RequestCount;
200*3e777be0SXin Li     const std::string&                             m_RequestInputsAndOutputsDumpDir;
201*3e777be0SXin Li     const bool                                     m_GpuProfilingEnabled;
202*3e777be0SXin Li     V1_3::Priority                                 m_ModelPriority;
203*3e777be0SXin Li 
204*3e777be0SXin Li     // Static to allow sharing of threadpool between ArmnnPreparedModel instances
205*3e777be0SXin Li     static std::unique_ptr<armnn::Threadpool>      m_Threadpool;
206*3e777be0SXin Li     std::shared_ptr<IWorkingMemHandle>             m_WorkingMemHandle;
207*3e777be0SXin Li     const bool                                     m_AsyncModelExecutionEnabled;
208*3e777be0SXin Li     const bool                                     m_EnableImport;
209*3e777be0SXin Li     const bool                                     m_EnableExport;
210*3e777be0SXin Li     const bool                                     m_PreparedFromCache;
211*3e777be0SXin Li };
212*3e777be0SXin Li 
213*3e777be0SXin Li }
214