// Source: /aosp_15_r20/external/armnn/include/armnn/Threadpool.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021-2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #if !defined(ARMNN_DISABLE_THREADS)
6 
7 #pragma once
8 
9 #include "IRuntime.hpp"
10 #include <armnn/Tensor.hpp>
11 #include <armnn/Types.hpp>
12 #include <stdint.h>
13 #include <thread>
14 #include <mutex>
15 #include <condition_variable>
16 #include <unordered_map>
17 #include <queue>
18 #include <iosfwd>
19 #include <memory>
20 #include <tuple>
21 #include <vector>
22 
23 namespace armnn
24 {
25 namespace experimental
26 {
27 class IAsyncExecutionCallback;
28 class IWorkingMemHandle;
29 
30 class Threadpool
31 {
32 public:
33     Threadpool(std::size_t numThreads,
34                IRuntime* runtimePtr,
35                std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
36 
~Threadpool()37     ~Threadpool()
38     {
39         TerminateThreadPool();
40     }
41 
42     void LoadMemHandles(std::vector<std::shared_ptr<IWorkingMemHandle>> memHandles);
43     void UnloadMemHandles(NetworkId networkId);
44 
45     /// Schedule an asynchronous execution on the loaded network
46     void Schedule(NetworkId networkId,
47                   const InputTensors &inputTensors,
48                   const OutputTensors &outputTensors,
49                   const QosExecPriority priority,
50                   std::shared_ptr<IAsyncExecutionCallback> cb);
51 
52     void TerminateThreadPool() noexcept;
53 
54 private:
55     using ExecutionTuple = std::tuple<NetworkId,
56                                       InputTensors,
57                                       OutputTensors,
58                                       std::shared_ptr<IAsyncExecutionCallback>>;
59 
60     using ExecutionQueue = std::queue<std::shared_ptr<ExecutionTuple>>;
61 
62     void ProcessExecPriorities(uint32_t index);
63 
64     IRuntime* m_RuntimePtr;
65 
66     ExecutionQueue m_HighPriorityQueue;
67     ExecutionQueue m_MediumPriorityQueue;
68     ExecutionQueue m_LowPriorityQueue;
69 
70     // Condition Variables require mutex which will guard the shared state.
71     // Has an event happened? Stop signal for example
72     std::condition_variable m_ThreadPoolEvent;
73     std::mutex m_ThreadPoolMutex;
74 
75     // The shared state for conditional variable
76     bool m_TerminatePool = false;
77 
78     std::unordered_map<NetworkId, std::vector<std::shared_ptr<IWorkingMemHandle>>> m_WorkingMemHandleMap;
79     std::vector<std::unique_ptr<std::thread>> m_Threads;
80 };
81 
82 } // namespace experimental
83 
84 } // namespace armnn
85 
86 #endif
87