1 #pragma once 2 3 #include <atomic> 4 #include <condition_variable> 5 #include <cstddef> 6 #include <functional> 7 #include <mutex> 8 #include <queue> 9 #include <thread> 10 #include <utility> 11 #include <vector> 12 13 #include <c10/macros/Export.h> 14 #include <c10/util/Registry.h> 15 #include <c10/util/numa.h> 16 #include <c10/util/thread_name.h> 17 18 namespace c10 { 19 20 class C10_API TaskThreadPoolBase { 21 public: 22 virtual void run(std::function<void()> func) = 0; 23 24 virtual size_t size() const = 0; 25 26 /** 27 * The number of available (i.e. idle) threads in this thread pool. 28 */ 29 virtual size_t numAvailable() const = 0; 30 31 /** 32 * Check if the current thread is from the thread pool. 33 */ 34 virtual bool inThreadPool() const = 0; 35 36 virtual ~TaskThreadPoolBase() noexcept = default; 37 38 static size_t defaultNumThreads(); 39 }; 40 41 class C10_API ThreadPool : public c10::TaskThreadPoolBase { 42 protected: 43 struct task_element_t { 44 bool run_with_id; 45 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 46 const std::function<void()> no_id; 47 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 48 const std::function<void(std::size_t)> with_id; 49 task_element_ttask_element_t50 explicit task_element_t(std::function<void()> f) 51 : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} task_element_ttask_element_t52 explicit task_element_t(std::function<void(std::size_t)> f) 53 : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} 54 }; 55 56 std::queue<task_element_t> tasks_; 57 std::vector<std::thread> threads_; 58 mutable std::mutex mutex_; 59 std::condition_variable condition_; 60 std::condition_variable completed_; 61 std::atomic_bool running_; 62 bool complete_; 63 std::size_t available_; 64 std::size_t total_; 65 int numa_node_id_; 66 67 public: 68 ThreadPool() = delete; 69 70 explicit ThreadPool( 71 int pool_size, 72 int numa_node_id = -1, 73 const std::function<void()>& init_thread = nullptr); 74 75 ~ThreadPool() override; 76 77 size_t size() const override; 78 79 size_t numAvailable() const override; 80 81 bool inThreadPool() const override; 82 83 void run(std::function<void()> func) override; 84 85 template <typename Task> runTaskWithID(Task task)86 void runTaskWithID(Task task) { 87 std::unique_lock<std::mutex> lock(mutex_); 88 89 // Set task and signal condition variable so that a worker thread will 90 // wake up and use the task. 91 tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task)); 92 complete_ = false; 93 condition_.notify_one(); 94 } 95 96 /// @brief Wait for queue to be empty 97 void waitWorkComplete(); 98 99 private: 100 // @brief Entry point for pool threads. 101 void main_loop(std::size_t index); 102 }; 103 104 class C10_API TaskThreadPool : public c10::ThreadPool { 105 public: 106 explicit TaskThreadPool(int pool_size, int numa_node_id = -1) 107 : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { 108 setThreadName("CaffeTaskThread"); 109 NUMABind(numa_node_id); 110 }) {} 111 }; 112 113 C10_DECLARE_SHARED_REGISTRY( 114 ThreadPoolRegistry, 115 TaskThreadPoolBase, 116 int, 117 int, 118 bool); 119 120 } // namespace c10 121