xref: /aosp_15_r20/external/armnn/src/armnn/AsyncExecutionCallback.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include <AsyncExecutionCallback.hpp>
7 
8 namespace armnn
9 {
10 
11 namespace experimental
12 {
13 
14 InferenceId AsyncExecutionCallback::nextID = 0u;
15 
Notify(armnn::Status status,InferenceTimingPair timeTaken)16 void AsyncExecutionCallback::Notify(armnn::Status status, InferenceTimingPair timeTaken)
17 {
18     {
19 #if !defined(ARMNN_DISABLE_THREADS)
20         std::lock_guard<std::mutex> hold(m_Mutex);
21 #endif
22         // store results and mark as notified
23         m_Status    = status;
24         m_StartTime = timeTaken.first;
25         m_EndTime   = timeTaken.second;
26         m_NotificationQueue.push(m_InferenceId);
27     }
28 #if !defined(ARMNN_DISABLE_THREADS)
29     m_Condition.notify_all();
30 #endif
31 }
32 
GetStatus() const33 armnn::Status AsyncExecutionCallback::GetStatus() const
34 {
35     return m_Status;
36 }
37 
GetStartTime() const38 HighResolutionClock AsyncExecutionCallback::GetStartTime() const
39 {
40     return m_StartTime;
41 }
42 
GetEndTime() const43 HighResolutionClock AsyncExecutionCallback::GetEndTime() const
44 {
45     return m_EndTime;
46 }
47 
GetNewCallback()48 std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNewCallback()
49 {
50     auto cb = std::make_unique<AsyncExecutionCallback>(m_NotificationQueue
51 #if !defined(ARMNN_DISABLE_THREADS)
52                                                        , m_Mutex
53                                                        , m_Condition
54 #endif
55         );
56     InferenceId id = cb->GetInferenceId();
57     m_Callbacks.insert({id, std::move(cb)});
58 
59     return m_Callbacks.at(id);
60 }
61 
GetNotifiedCallback()62 std::shared_ptr<AsyncExecutionCallback> AsyncCallbackManager::GetNotifiedCallback()
63 {
64 #if !defined(ARMNN_DISABLE_THREADS)
65     std::unique_lock<std::mutex> lock(m_Mutex);
66 
67     m_Condition.wait(lock, [this] { return !m_NotificationQueue.empty(); });
68 #endif
69     InferenceId id = m_NotificationQueue.front();
70     m_NotificationQueue.pop();
71 
72     std::shared_ptr<AsyncExecutionCallback> callback = m_Callbacks.at(id);
73     m_Callbacks.erase(id);
74     return callback;
75 }
76 
77 } // namespace experimental
78 
79 } // namespace armnn
80