xref: /aosp_15_r20/external/pytorch/c10/core/thread_pool.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/thread_pool.h>
2 #include <c10/util/Logging.h>
3 #include <c10/util/thread_name.h>
4 #if !defined(__powerpc__) && !defined(__s390x__)
5 #include <cpuinfo.h>
6 #endif
7 
8 namespace c10 {
9 
defaultNumThreads()10 size_t TaskThreadPoolBase::defaultNumThreads() {
11   size_t num_threads = 0;
12 #if !defined(__powerpc__) && !defined(__s390x__)
13   if (cpuinfo_initialize()) {
14     // In cpuinfo parlance cores are physical ones and processors are virtual
15     // ThreadPool should be defaulted to number of physical cores
16     size_t num_cores = cpuinfo_get_cores_count();
17     num_threads = cpuinfo_get_processors_count();
18     if (num_cores > 0 && num_cores < num_threads) {
19       return num_cores;
20     }
21     if (num_threads > 0) {
22       return num_threads;
23     }
24   }
25 #endif
26   num_threads = std::thread::hardware_concurrency();
27   if (num_threads == 0) {
28     num_threads = 1;
29   }
30   return num_threads;
31 }
32 
ThreadPool(int pool_size,int numa_node_id,const std::function<void ()> & init_thread)33 ThreadPool::ThreadPool(
34     int pool_size,
35     int numa_node_id,
36     const std::function<void()>& init_thread)
37     : threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
38       running_(true),
39       complete_(true),
40       available_(threads_.size()),
41       total_(threads_.size()),
42       numa_node_id_(numa_node_id) {
43   for (std::size_t i = 0; i < threads_.size(); ++i) {
44     threads_[i] = std::thread([this, i, init_thread]() {
45       c10::setThreadName("pt_thread_pool");
46       if (init_thread) {
47         init_thread();
48       }
49       this->main_loop(i);
50     });
51   }
52 }
53 
~ThreadPool()54 ThreadPool::~ThreadPool() {
55   // Set running flag to false then notify all threads.
56   {
57     std::unique_lock<std::mutex> lock(mutex_);
58     running_ = false;
59     condition_.notify_all();
60   }
61 
62   for (auto& t : threads_) {
63     try {
64       t.join();
65     } catch (const std::exception&) {
66     }
67   }
68 }
69 
// Number of worker threads owned by this pool (fixed at construction,
// so no lock is needed).
size_t ThreadPool::size() const {
  return threads_.size();
}
73 
numAvailable() const74 size_t ThreadPool::numAvailable() const {
75   std::unique_lock<std::mutex> lock(mutex_);
76   return available_;
77 }
78 
inThreadPool() const79 bool ThreadPool::inThreadPool() const {
80   for (auto& thread : threads_) {
81     if (thread.get_id() == std::this_thread::get_id()) {
82       return true;
83     }
84   }
85   return false;
86 }
87 
run(std::function<void ()> func)88 void ThreadPool::run(std::function<void()> func) {
89   if (threads_.empty()) {
90     throw std::runtime_error("No threads to run a task");
91   }
92   std::unique_lock<std::mutex> lock(mutex_);
93 
94   // Set task and signal condition variable so that a worker thread will
95   // wake up and use the task.
96   tasks_.emplace(std::move(func));
97   complete_ = false;
98   condition_.notify_one();
99 }
100 
waitWorkComplete()101 void ThreadPool::waitWorkComplete() {
102   std::unique_lock<std::mutex> lock(mutex_);
103   completed_.wait(lock, [&]() { return complete_; });
104 }
105 
// Worker-thread body.  `index` identifies this worker's slot in threads_
// and is forwarded to tasks that were queued with a run-with-id variant.
// Invariant: mutex_ is held everywhere in this loop EXCEPT while a task
// is actually executing (and being destroyed).  Returns only after the
// destructor clears running_.
void ThreadPool::main_loop(std::size_t index) {
  std::unique_lock<std::mutex> lock(mutex_);
  while (running_) {
    // Wait on condition variable while the task is empty and
    // the pool is still running.
    condition_.wait(lock, [&]() { return !tasks_.empty() || !running_; });
    // If pool is no longer running, break out of loop.
    if (!running_) {
      break;
    }

    // Copy task locally and remove from the queue.  This is
    // done within its own scope so that the task object is
    // destructed immediately after running the task.  This is
    // useful in the event that the function contains
    // shared_ptr arguments bound via bind.
    {
      task_element_t tasks = std::move(tasks_.front());
      tasks_.pop();
      // Decrement count, indicating thread is no longer available.
      --available_;

      // Drop the lock while user code runs so other workers can dequeue
      // and run() can enqueue concurrently.
      lock.unlock();

      // Run the task.
      try {
        if (tasks.run_with_id) {
          tasks.with_id(index);
        } else {
          tasks.no_id();
        }
      } catch (const std::exception& e) {
        // A throwing task must not kill the worker; log and keep serving.
        LOG(ERROR) << "Exception in thread pool task: " << e.what();
      } catch (...) {
        LOG(ERROR) << "Exception in thread pool task: unknown";
      }

      // Destruct tasks before taking the lock.  As tasks
      // are user provided std::function, they can run
      // arbitrary code during destruction, including code
      // that can reentrantly call into ThreadPool (which would
      // cause a deadlock if we were holding the lock).
    }

    // Update status of empty, maybe
    // Need to recover the lock first
    lock.lock();

    // Increment count, indicating thread is available.
    ++available_;
    if (tasks_.empty() && available_ == total_) {
      complete_ = true;
      completed_.notify_one();
    }

    // Deliberately hold the lock on the backedge, so this thread has an
    // opportunity to acquire a new task before another thread acquires
    // the lock.
  } // while running_
}
166 
// Shared registry of TaskThreadPoolBase factories, keyed by name with
// creator arguments (int, int, bool).  NOTE(review): the argument meanings
// are not visible in this file — by analogy with the ThreadPool constructor
// they presumably correspond to (pool_size, numa_node_id, <flag>); confirm
// against thread_pool.h and the registered creators.
C10_DEFINE_SHARED_REGISTRY(
    ThreadPoolRegistry,
    TaskThreadPoolBase,
    int,
    int,
    bool);
173 } // namespace c10
174