xref: /aosp_15_r20/external/armnn/src/armnn/Threadpool.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #if !defined(ARMNN_DISABLE_THREADS)
6 
#include <armnn/Threadpool.hpp>

#include <armnn/utility/Timer.hpp>

#include <exception>
#include <utility>
10 
11 namespace armnn
12 {
13 namespace experimental
14 {
15 
Threadpool(std::size_t numThreads,IRuntime * runtimePtr,std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)16 Threadpool::Threadpool(std::size_t numThreads,
17                        IRuntime* runtimePtr,
18                        std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
19     : m_RuntimePtr(runtimePtr)
20 {
21     for (auto i = 0u; i < numThreads; ++i)
22     {
23         m_Threads.emplace_back(std::make_unique<std::thread>(&Threadpool::ProcessExecPriorities, this, i));
24     }
25 
26     LoadMemHandles(memHandles);
27 }
28 
LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)29 void Threadpool::LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles)
30 {
31     if (memHandles.size() == 0)
32     {
33         throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Size of memHandles vector must be greater than 0");
34     }
35 
36     if (memHandles.size() != m_Threads.size())
37     {
38         throw armnn::RuntimeException(
39                 "Threadpool::UnloadMemHandles: Size of memHandles vector must match the number of threads");
40     }
41 
42     NetworkId networkId = memHandles[0]->GetNetworkId();
43     for (uint32_t i = 1; i < memHandles.size(); ++i)
44     {
45         if (networkId != memHandles[i]->GetNetworkId())
46         {
47             throw armnn::RuntimeException(
48                     "Threadpool::UnloadMemHandles: All network ids must be identical in memHandles");
49         }
50     }
51 
52     std::pair<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> pair {networkId, memHandles};
53 
54     m_WorkingMemHandleMap.insert(pair);
55 }
56 
UnloadMemHandles(NetworkId networkId)57 void Threadpool::UnloadMemHandles(NetworkId networkId)
58 {
59     if (m_WorkingMemHandleMap.find(networkId) != m_WorkingMemHandleMap.end())
60     {
61         m_WorkingMemHandleMap.erase(networkId);
62     }
63     else
64     {
65        throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
66     }
67 }
68 
Schedule(NetworkId networkId,const InputTensors & inputTensors,const OutputTensors & outputTensors,const QosExecPriority priority,std::shared_ptr<IAsyncExecutionCallback> cb)69 void Threadpool::Schedule(NetworkId networkId,
70                           const InputTensors& inputTensors,
71                           const OutputTensors& outputTensors,
72                           const QosExecPriority priority,
73                           std::shared_ptr<IAsyncExecutionCallback> cb)
74 {
75     if (m_WorkingMemHandleMap.find(networkId) == m_WorkingMemHandleMap.end())
76     {
77         throw armnn::RuntimeException("Threadpool::UnloadMemHandles: Unknown NetworkId");
78     }
79 
80     // Group execution parameters so that they can be easily added to the queue
81     ExecutionTuple groupExecParams = std::make_tuple(networkId, inputTensors, outputTensors, cb);
82 
83     std::shared_ptr<ExecutionTuple> operation = std::make_shared<ExecutionTuple>(groupExecParams);
84 
85     // Add a message to the queue and notify the request thread
86     std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
87     switch (priority)
88     {
89         case QosExecPriority::High:
90             m_HighPriorityQueue.push(operation);
91             break;
92         case QosExecPriority::Low:
93             m_LowPriorityQueue.push(operation);
94             break;
95         case QosExecPriority::Medium:
96         default:
97             m_MediumPriorityQueue.push(operation);
98     }
99     m_ThreadPoolEvent.notify_one();
100 }
101 
TerminateThreadPool()102 void Threadpool::TerminateThreadPool() noexcept
103 {
104     {
105         std::unique_lock<std::mutex> threadPoolLock(m_ThreadPoolMutex);
106         m_TerminatePool = true;
107     }
108 
109     m_ThreadPoolEvent.notify_all();
110 
111     for (auto &thread : m_Threads)
112     {
113         thread->join();
114     }
115 }
116 
ProcessExecPriorities(uint32_t index)117 void Threadpool::ProcessExecPriorities(uint32_t index)
118 {
119     int expireRate = EXPIRE_RATE;
120     int highPriorityCount = 0;
121     int mediumPriorityCount = 0;
122 
123     while (true)
124     {
125         std::shared_ptr<ExecutionTuple> currentExecInProgress(nullptr);
126         {
127             // Wait for a message to be added to the queue
128             // This is in a separate scope to minimise the lifetime of the lock
129             std::unique_lock<std::mutex> lock(m_ThreadPoolMutex);
130 
131             m_ThreadPoolEvent.wait(lock,
132                                    [=]
133                                    {
134                                        return m_TerminatePool || !m_HighPriorityQueue.empty() ||
135                                               !m_MediumPriorityQueue.empty() || !m_LowPriorityQueue.empty();
136                                    });
137 
138             if (m_TerminatePool && m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() &&
139                 m_LowPriorityQueue.empty())
140             {
141                 break;
142             }
143 
144             // Get the message to process from the front of each queue based on priority from high to low
145             // Get high priority first if it does not exceed the expire rate
146             if (!m_HighPriorityQueue.empty() && highPriorityCount < expireRate)
147             {
148                 currentExecInProgress = m_HighPriorityQueue.front();
149                 m_HighPriorityQueue.pop();
150                 highPriorityCount += 1;
151             }
152                 // If high priority queue is empty or the count exceeds the expire rate, get medium priority message
153             else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < expireRate)
154             {
155                 currentExecInProgress = m_MediumPriorityQueue.front();
156                 m_MediumPriorityQueue.pop();
157                 mediumPriorityCount += 1;
158                 // Reset high priority count
159                 highPriorityCount = 0;
160             }
161                 // If medium priority queue is empty or the count exceeds the expire rate, get low priority message
162             else if (!m_LowPriorityQueue.empty())
163             {
164                 currentExecInProgress = m_LowPriorityQueue.front();
165                 m_LowPriorityQueue.pop();
166                 // Reset high and medium priority count
167                 highPriorityCount = 0;
168                 mediumPriorityCount = 0;
169             }
170             else
171             {
172                 // Reset high and medium priority count
173                 highPriorityCount = 0;
174                 mediumPriorityCount = 0;
175                 continue;
176             }
177         }
178 
179         // invoke the asynchronous execution method
180         auto networkId = std::get<0>(*currentExecInProgress);
181         auto inputTensors = std::get<1>(*currentExecInProgress);
182         auto outputTensors = std::get<2>(*currentExecInProgress);
183         auto cb = std::get<3>(*currentExecInProgress);
184 
185         // Get time at start of inference
186         HighResolutionClock startTime = armnn::GetTimeNow();
187 
188         try // executing the inference
189         {
190             IWorkingMemHandle& memHandle = *(m_WorkingMemHandleMap.at(networkId))[index];
191 
192             // Execute and populate the time at end of inference in the callback
193             m_RuntimePtr->Execute(memHandle, inputTensors, outputTensors) == Status::Success ?
194             cb->Notify(Status::Success, std::make_pair(startTime, armnn::GetTimeNow())) :
195             cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
196         }
197         catch (const RuntimeException&)
198         {
199             cb->Notify(Status::Failure, std::make_pair(startTime, armnn::GetTimeNow()));
200         }
201     }
202 }
203 
204 } // namespace experimental
205 
206 } // namespace armnn
207 
208 #endif
209