xref: /aosp_15_r20/external/pytorch/c10/core/thread_pool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <atomic>
4 #include <condition_variable>
5 #include <cstddef>
6 #include <functional>
7 #include <mutex>
8 #include <queue>
9 #include <thread>
10 #include <utility>
11 #include <vector>
12 
13 #include <c10/macros/Export.h>
14 #include <c10/util/Registry.h>
15 #include <c10/util/numa.h>
16 #include <c10/util/thread_name.h>
17 
18 namespace c10 {
19 
20 class C10_API TaskThreadPoolBase {
21  public:
22   virtual void run(std::function<void()> func) = 0;
23 
24   virtual size_t size() const = 0;
25 
26   /**
27    * The number of available (i.e. idle) threads in this thread pool.
28    */
29   virtual size_t numAvailable() const = 0;
30 
31   /**
32    * Check if the current thread is from the thread pool.
33    */
34   virtual bool inThreadPool() const = 0;
35 
36   virtual ~TaskThreadPoolBase() noexcept = default;
37 
38   static size_t defaultNumThreads();
39 };
40 
41 class C10_API ThreadPool : public c10::TaskThreadPoolBase {
42  protected:
43   struct task_element_t {
44     bool run_with_id;
45     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
46     const std::function<void()> no_id;
47     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
48     const std::function<void(std::size_t)> with_id;
49 
task_element_ttask_element_t50     explicit task_element_t(std::function<void()> f)
51         : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
task_element_ttask_element_t52     explicit task_element_t(std::function<void(std::size_t)> f)
53         : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
54   };
55 
56   std::queue<task_element_t> tasks_;
57   std::vector<std::thread> threads_;
58   mutable std::mutex mutex_;
59   std::condition_variable condition_;
60   std::condition_variable completed_;
61   std::atomic_bool running_;
62   bool complete_;
63   std::size_t available_;
64   std::size_t total_;
65   int numa_node_id_;
66 
67  public:
68   ThreadPool() = delete;
69 
70   explicit ThreadPool(
71       int pool_size,
72       int numa_node_id = -1,
73       const std::function<void()>& init_thread = nullptr);
74 
75   ~ThreadPool() override;
76 
77   size_t size() const override;
78 
79   size_t numAvailable() const override;
80 
81   bool inThreadPool() const override;
82 
83   void run(std::function<void()> func) override;
84 
85   template <typename Task>
runTaskWithID(Task task)86   void runTaskWithID(Task task) {
87     std::unique_lock<std::mutex> lock(mutex_);
88 
89     // Set task and signal condition variable so that a worker thread will
90     // wake up and use the task.
91     tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task));
92     complete_ = false;
93     condition_.notify_one();
94   }
95 
96   /// @brief Wait for queue to be empty
97   void waitWorkComplete();
98 
99  private:
100   // @brief Entry point for pool threads.
101   void main_loop(std::size_t index);
102 };
103 
104 class C10_API TaskThreadPool : public c10::ThreadPool {
105  public:
106   explicit TaskThreadPool(int pool_size, int numa_node_id = -1)
107       : ThreadPool(pool_size, numa_node_id, [numa_node_id]() {
108           setThreadName("CaffeTaskThread");
109           NUMABind(numa_node_id);
110         }) {}
111 };
112 
113 C10_DECLARE_SHARED_REGISTRY(
114     ThreadPoolRegistry,
115     TaskThreadPoolBase,
116     int,
117     int,
118     bool);
119 
120 } // namespace c10
121