//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#if !defined(ARMNN_DISABLE_THREADS)

#include <armnn/Threadpool.hpp>

#include <armnn/utility/Timer.hpp>

namespace armnn
{
namespace experimental
{

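// Creates one worker thread per requested thread and registers the supplied
// per-thread working memory handles via LoadMemHandles. The runtime pointer is
// only stored, not owned, so it must outlive the Threadpool.
//
// A minimal usage sketch (illustrative only, not compiled here). It assumes a
// network has already been optimised and loaded into `runtime` with async
// execution enabled, and that `networkId`, `inputTensors`, `outputTensors` and
// `callback` are supplied by the caller:
//
//   std::vector<std::shared_ptr<IWorkingMemHandle>> handles;
//   for (std::size_t i = 0; i < numThreads; ++i)
//   {
//       handles.emplace_back(runtime->CreateWorkingMemHandle(networkId));
//   }
//   Threadpool pool(numThreads, runtime.get(), handles);
//   pool.Schedule(networkId, inputTensors, outputTensors, QosExecPriority::Medium, callback);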
Threadpool::Threadpool(std::size_t numThreads,
                       IRuntime* runtimePtr,
                       std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
    : m_RuntimePtr(runtimePtr)
{
    for (auto i = 0u; i < numThreads; ++i)
    {
        m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
    }

    LoadMemHandles(memHandles);
}

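// Validates the handle vector (non-empty, one handle per worker thread, all for
// the same network) and stores it against that network's id so each worker can
// index its own handle during execution.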
void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
{
    if (memHandles.size() == 0)
    {
        throw armnn::RuntimeException("Threadpool::LoadMemHandles: Size of memHandles vector must be greater than 0");
    }

    if (memHandles.size() != m_Threads.size())
    {
        throw armnn::RuntimeException(
            "Threadpool::LoadMemHandles: Size of memHandles vector must match the number of threads");
    }

    NetworkId networkId = memHandles[0]->GetNetworkId();
    for (uint32_t i = 1; i < memHandles.size(); ++i)
    {
        if (networkId != memHandles[i]->GetNetworkId())
        {
            throw armnn::RuntimeException(
                "Threadpool::LoadMemHandles: All network ids must be identical in memHandles");
        }
    }

    std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};

    m_WorkingMemHandleMap.insert(pair);
}

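// Removes the working memory handles registered for the given network, throwing
// if the network id was never loaded.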
void Threadpool::UnloadMemHandles(NetworkId networkId)
{
    if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
    {
        m_WorkingMemHandleMap.erase(networkId);
    }
    else
    {
        throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
    }
}

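// Packages an inference request (network id, tensors and completion callback)
// into an ExecutionTuple, pushes it onto the queue matching the requested
// priority and wakes one worker thread.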
void Threadpool::Schedule(NetworkId networkId,
                          const InputTensors& inputTensors,
                          const OutputTensors& outputTensors,
                          const QosExecPriority priority,
                          std::shared_ptr<IAsyncExecutionCallback> cb)
{
    if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
    {
        throw armnn::RuntimeException("Threadpool::Schedule: Unknown NetworkId");
    }

    // Group execution parameters so that they can be easily added to the queue
    ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);

    std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);

    // Add a message to the queue and notify the request thread
    std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
    switch (priority)
    {
        case QosExecPriority::High:
            m_HighPriorityQueue.push(operation);
            break;
        case QosExecPriority::Low:
            m_LowPriorityQueue.push(operation);
            break;
        case QosExecPriority::Medium:
        default:
            m_MediumPriorityQueue.push(operation);
    }
    m_ThreadPoolEvent.notify_one();
}

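// Signals the worker threads to finish any remaining queued work and exit,
// then joins them; noexcept so it is safe to call during destruction.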
void Threadpool::TerminateThreadPool() noexcept
{
    {
        std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
        m_TerminatePool = true;
    }

    m_ThreadPoolEvent.notify_all();

    for (auto &thread : m_Threads)
    {
        thread->join();
    }
}

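// Worker loop run by each thread. Waits on the condition variable until work is
// queued or termination is requested, then drains the queues highest priority
// first. The expire rate bounds how many consecutive high (or medium) priority
// jobs a thread takes before it services a lower priority queue, so lower
// priority requests are not starved. The thread's index selects its dedicated
// working memory handle for the scheduled network.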
void Threadpool::ProcessExecPriorities(uint32_t index)
{
    int expireRate = EXPIRE_RATE;
    int highPriorityCount = 0;
    int mediumPriorityCount = 0;

    while (true)
    {
        std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
        {
            // Wait for a message to be added to the queue
            // This is in a separate scope to minimise the lifetime of the lock
            std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);

            m_ThreadPoolEvent.wait(lock,
                                   [=]
                                   {
                                       return m_TerminatePool || !m_HighPriorityQueue.empty() ||
                                              !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
                                   });

            if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
                m_LowPriorityQueue.empty())
            {
                break;
            }

            // Get the message to process from the front of each queue based on priority from high to low
            // Get high priority first if it does not exceed the expire rate
            if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
            {
                currentExecInProgress = m_HighPriorityQueue.front();
                m_HighPriorityQueue.pop();
                highPriorityCount += 1;
            }
            // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
            {
                currentExecInProgress = m_MediumPriorityQueue.front();
                m_MediumPriorityQueue.pop();
                mediumPriorityCount += 1;
                // Reset high priority count
                highPriorityCount = 0;
            }
            // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
            else if (!m_LowPriorityQueue.empty())
            {
                currentExecInProgress = m_LowPriorityQueue.front();
                m_LowPriorityQueue.pop();
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
            }
            else
            {
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
                continue;
            }
        }

        // invoke the asynchronous execution method
        auto networkId = std::get<0>(*currentExecInProgress);
        auto inputTensors = std::get<1>(*currentExecInProgress);
        auto outputTensors = std::get<2>(*currentExecInProgress);
        auto cb = std::get<3>(*currentExecInProgress);

        // Get time at start of inference
        HighResolutionClock startTime = armnn::GetTimeNow();

        try // executing the inference
        {
            IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];

            // Execute and populate the time at end of inference in the callback
            m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
                cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
                cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
        }
        catch (const RuntimeException&)
        {
            cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
        }
    }
}

} // namespace experimental

} // namespace armnn

#endif