xref: /aosp_15_r20/external/android-nn-driver/ArmnnPreparedModel.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ArmnnDriver.hpp"
9 #include "ArmnnDriverImpl.hpp"
10 #include "RequestThread.hpp"
11 
12 #include <NeuralNetworks.h>
13 #include <armnn/ArmNN.hpp>
14 #include <armnn/Threadpool.hpp>
15 
16 #include <string>
17 #include <vector>
18 
19 namespace armnn_driver
20 {
// Callback invoked when an ArmNN execution finishes: reports the HAL 1.0
// error status plus the name of the calling function (used for logging).
using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;

// Wraps the HAL 1.0 execution callback so it can be carried through the
// request-scheduling machinery as a single object.
struct ArmnnCallback_1_0
{
    armnnExecuteCallback_1_0 callback;
};

// The 1.0 HAL carries no additional per-execution state, so its execution
// context is intentionally empty.
struct ExecutionContext_1_0 {};

// Pairs the 1.0 callback type with its (empty) execution context for use by
// the templated RequestThread / ExecuteGraph code.
using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
31 
32 template <typename HalVersion>
33 class ArmnnPreparedModel : public V1_0::IPreparedModel
34 {
35 public:
36     using HalModel = typename HalVersion::Model;
37 
38     ArmnnPreparedModel(armnn::NetworkId networkId,
39                        armnn::IRuntime* runtime,
40                        const HalModel& model,
41                        const std::string& requestInputsAndOutputsDumpDir,
42                        const bool gpuProfilingEnabled,
43                        const bool asyncModelExecutionEnabled = false,
44                        const unsigned int numberOfThreads = 1,
45                        const bool importEnabled = false,
46                        const bool exportEnabled = false);
47 
48     virtual ~ArmnnPreparedModel();
49 
50     virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
51                                               const ::android::sp<V1_0::IExecutionCallback>& callback) override;
52 
53     /// execute the graph prepared from the request
54     void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
55                       armnn::InputTensors& inputTensors,
56                       armnn::OutputTensors& outputTensors,
57                       CallbackContext_1_0 callback);
58 
59     /// Executes this model with dummy inputs (e.g. all zeroes).
60     /// \return false on failure, otherwise true
61     bool ExecuteWithDummyInputs();
62 
63 private:
64 
65     template<typename CallbackContext>
66     class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
67     {
68     public:
ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)69         ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
70                                 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
71                                 std::shared_ptr<armnn::InputTensors>& inputTensors,
72                                 std::shared_ptr<armnn::OutputTensors>& outputTensors,
73                                 CallbackContext callbackContext) :
74                 m_Model(model),
75                 m_MemPools(pMemPools),
76                 m_InputTensors(inputTensors),
77                 m_OutputTensors(outputTensors),
78                 m_CallbackContext(callbackContext)
79         {}
80 
81         void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
82 
83         ArmnnPreparedModel<HalVersion>* m_Model;
84         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
85         std::shared_ptr<armnn::InputTensors> m_InputTensors;
86         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
87         CallbackContext m_CallbackContext;
88     };
89 
90     template <typename TensorBindingCollection>
91     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
92 
93     /// schedule the graph prepared from the request for execution
94     template<typename CallbackContext>
95     void ScheduleGraphForExecution(
96             std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
97             std::shared_ptr<armnn::InputTensors>& inputTensors,
98             std::shared_ptr<armnn::OutputTensors>& outputTensors,
99             CallbackContext m_CallbackContext);
100 
101     armnn::NetworkId                          m_NetworkId;
102     armnn::IRuntime*                          m_Runtime;
103     HalModel                                  m_Model;
104     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
105     // It is specific to this class, so it is declared as static here
106     static RequestThread<ArmnnPreparedModel,
107                          HalVersion,
108                          CallbackContext_1_0> m_RequestThread;
109     uint32_t                                  m_RequestCount;
110     const std::string&                        m_RequestInputsAndOutputsDumpDir;
111     const bool                                m_GpuProfilingEnabled;
112     // Static to allow sharing of threadpool between ArmnnPreparedModel instances
113     static std::unique_ptr<armnn::Threadpool> m_Threadpool;
114     std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
115     const bool m_AsyncModelExecutionEnabled;
116     const bool m_EnableImport;
117     const bool m_EnableExport;
118 };
119 
120 }
121