xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
11 #define EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
12 
13 namespace Eigen {
14 
15 template <typename Environment>
16 class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
17  public:
18   typedef typename Environment::Task Task;
19   typedef RunQueue<Task, 1024> Queue;
20 
21   ThreadPoolTempl(int num_threads, Environment env = Environment())
ThreadPoolTempl(num_threads,true,env)22       : ThreadPoolTempl(num_threads, true, env) {}
23 
24   ThreadPoolTempl(int num_threads, bool allow_spinning,
25                   Environment env = Environment())
env_(env)26       : env_(env),
27         num_threads_(num_threads),
28         allow_spinning_(allow_spinning),
29         thread_data_(num_threads),
30         all_coprimes_(num_threads),
31         waiters_(num_threads),
32         global_steal_partition_(EncodePartition(0, num_threads_)),
33         blocked_(0),
34         spinning_(0),
35         done_(false),
36         cancelled_(false),
37         ec_(waiters_) {
38     waiters_.resize(num_threads_);
39     // Calculate coprimes of all numbers [1, num_threads].
40     // Coprimes are used for random walks over all threads in Steal
41     // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
42     // a random starting thread index t and calculate num_threads - 1 subsequent
43     // indices as (t + coprime) % num_threads, we will cover all threads without
44     // repetitions (effectively getting a presudo-random permutation of thread
45     // indices).
46     eigen_plain_assert(num_threads_ < kMaxThreads);
47     for (int i = 1; i <= num_threads_; ++i) {
48       all_coprimes_.emplace_back(i);
49       ComputeCoprimes(i, &all_coprimes_.back());
50     }
51 #ifndef EIGEN_THREAD_LOCAL
52     init_barrier_.reset(new Barrier(num_threads_));
53 #endif
54     thread_data_.resize(num_threads_);
55     for (int i = 0; i < num_threads_; i++) {
56       SetStealPartition(i, EncodePartition(0, num_threads_));
57       thread_data_[i].thread.reset(
58           env_.CreateThread([this, i]() { WorkerLoop(i); }));
59     }
60 #ifndef EIGEN_THREAD_LOCAL
61     // Wait for workers to initialize per_thread_map_. Otherwise we might race
62     // with them in Schedule or CurrentThreadId.
63     init_barrier_->Wait();
64 #endif
65   }
66 
~ThreadPoolTempl()67   ~ThreadPoolTempl() {
68     done_ = true;
69 
70     // Now if all threads block without work, they will start exiting.
71     // But note that threads can continue to work arbitrary long,
72     // block, submit new work, unblock and otherwise live full life.
73     if (!cancelled_) {
74       ec_.Notify(true);
75     } else {
76       // Since we were cancelled, there might be entries in the queues.
77       // Empty them to prevent their destructor from asserting.
78       for (size_t i = 0; i < thread_data_.size(); i++) {
79         thread_data_[i].queue.Flush();
80       }
81     }
82     // Join threads explicitly (by destroying) to avoid destruction order within
83     // this class.
84     for (size_t i = 0; i < thread_data_.size(); ++i)
85       thread_data_[i].thread.reset();
86   }
87 
SetStealPartitions(const std::vector<std::pair<unsigned,unsigned>> & partitions)88   void SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
89     eigen_plain_assert(partitions.size() == static_cast<std::size_t>(num_threads_));
90 
91     // Pass this information to each thread queue.
92     for (int i = 0; i < num_threads_; i++) {
93       const auto& pair = partitions[i];
94       unsigned start = pair.first, end = pair.second;
95       AssertBounds(start, end);
96       unsigned val = EncodePartition(start, end);
97       SetStealPartition(i, val);
98     }
99   }
100 
Schedule(std::function<void ()> fn)101   void Schedule(std::function<void()> fn) EIGEN_OVERRIDE {
102     ScheduleWithHint(std::move(fn), 0, num_threads_);
103   }
104 
ScheduleWithHint(std::function<void ()> fn,int start,int limit)105   void ScheduleWithHint(std::function<void()> fn, int start,
106                         int limit) override {
107     Task t = env_.CreateTask(std::move(fn));
108     PerThread* pt = GetPerThread();
109     if (pt->pool == this) {
110       // Worker thread of this pool, push onto the thread's queue.
111       Queue& q = thread_data_[pt->thread_id].queue;
112       t = q.PushFront(std::move(t));
113     } else {
114       // A free-standing thread (or worker of another pool), push onto a random
115       // queue.
116       eigen_plain_assert(start < limit);
117       eigen_plain_assert(limit <= num_threads_);
118       int num_queues = limit - start;
119       int rnd = Rand(&pt->rand) % num_queues;
120       eigen_plain_assert(start + rnd < limit);
121       Queue& q = thread_data_[start + rnd].queue;
122       t = q.PushBack(std::move(t));
123     }
124     // Note: below we touch this after making w available to worker threads.
125     // Strictly speaking, this can lead to a racy-use-after-free. Consider that
126     // Schedule is called from a thread that is neither main thread nor a worker
127     // thread of this pool. Then, execution of w directly or indirectly
128     // completes overall computations, which in turn leads to destruction of
129     // this. We expect that such scenario is prevented by program, that is,
130     // this is kept alive while any threads can potentially be in Schedule.
131     if (!t.f) {
132       ec_.Notify(false);
133     } else {
134       env_.ExecuteTask(t);  // Push failed, execute directly.
135     }
136   }
137 
Cancel()138   void Cancel() EIGEN_OVERRIDE {
139     cancelled_ = true;
140     done_ = true;
141 
142     // Let each thread know it's been cancelled.
143 #ifdef EIGEN_THREAD_ENV_SUPPORTS_CANCELLATION
144     for (size_t i = 0; i < thread_data_.size(); i++) {
145       thread_data_[i].thread->OnCancel();
146     }
147 #endif
148 
149     // Wake up the threads without work to let them exit on their own.
150     ec_.Notify(true);
151   }
152 
NumThreads()153   int NumThreads() const EIGEN_FINAL { return num_threads_; }
154 
CurrentThreadId()155   int CurrentThreadId() const EIGEN_FINAL {
156     const PerThread* pt = const_cast<ThreadPoolTempl*>(this)->GetPerThread();
157     if (pt->pool == this) {
158       return pt->thread_id;
159     } else {
160       return -1;
161     }
162   }
163 
164  private:
165   // Create a single atomic<int> that encodes start and limit information for
166   // each thread.
167   // We expect num_threads_ < 65536, so we can store them in a single
168   // std::atomic<unsigned>.
169   // Exposed publicly as static functions so that external callers can reuse
170   // this encode/decode logic for maintaining their own thread-safe copies of
171   // scheduling and steal domain(s).
172   static const int kMaxPartitionBits = 16;
173   static const int kMaxThreads = 1 << kMaxPartitionBits;
174 
EncodePartition(unsigned start,unsigned limit)175   inline unsigned EncodePartition(unsigned start, unsigned limit) {
176     return (start << kMaxPartitionBits) | limit;
177   }
178 
DecodePartition(unsigned val,unsigned * start,unsigned * limit)179   inline void DecodePartition(unsigned val, unsigned* start, unsigned* limit) {
180     *limit = val & (kMaxThreads - 1);
181     val >>= kMaxPartitionBits;
182     *start = val;
183   }
184 
AssertBounds(int start,int end)185   void AssertBounds(int start, int end) {
186     eigen_plain_assert(start >= 0);
187     eigen_plain_assert(start < end);  // non-zero sized partition
188     eigen_plain_assert(end <= num_threads_);
189   }
190 
SetStealPartition(size_t i,unsigned val)191   inline void SetStealPartition(size_t i, unsigned val) {
192     thread_data_[i].steal_partition.store(val, std::memory_order_relaxed);
193   }
194 
GetStealPartition(int i)195   inline unsigned GetStealPartition(int i) {
196     return thread_data_[i].steal_partition.load(std::memory_order_relaxed);
197   }
198 
ComputeCoprimes(int N,MaxSizeVector<unsigned> * coprimes)199   void ComputeCoprimes(int N, MaxSizeVector<unsigned>* coprimes) {
200     for (int i = 1; i <= N; i++) {
201       unsigned a = i;
202       unsigned b = N;
203       // If GCD(a, b) == 1, then a and b are coprimes.
204       while (b != 0) {
205         unsigned tmp = a;
206         a = b;
207         b = tmp % b;
208       }
209       if (a == 1) {
210         coprimes->push_back(i);
211       }
212     }
213   }
214 
215   typedef typename Environment::EnvThread Thread;
216 
217   struct PerThread {
PerThreadPerThread218     constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) {}
219     ThreadPoolTempl* pool;  // Parent pool, or null for normal threads.
220     uint64_t rand;          // Random generator state.
221     int thread_id;          // Worker thread index in pool.
222 #ifndef EIGEN_THREAD_LOCAL
223     // Prevent false sharing.
224     char pad_[128];
225 #endif
226   };
227 
228   struct ThreadData {
ThreadDataThreadData229     constexpr ThreadData() : thread(), steal_partition(0), queue() {}
230     std::unique_ptr<Thread> thread;
231     std::atomic<unsigned> steal_partition;
232     Queue queue;
233   };
234 
235   Environment env_;
236   const int num_threads_;
237   const bool allow_spinning_;
238   MaxSizeVector<ThreadData> thread_data_;
239   MaxSizeVector<MaxSizeVector<unsigned>> all_coprimes_;
240   MaxSizeVector<EventCount::Waiter> waiters_;
241   unsigned global_steal_partition_;
242   std::atomic<unsigned> blocked_;
243   std::atomic<bool> spinning_;
244   std::atomic<bool> done_;
245   std::atomic<bool> cancelled_;
246   EventCount ec_;
247 #ifndef EIGEN_THREAD_LOCAL
248   std::unique_ptr<Barrier> init_barrier_;
249   std::mutex per_thread_map_mutex_;  // Protects per_thread_map_.
250   std::unordered_map<uint64_t, std::unique_ptr<PerThread>> per_thread_map_;
251 #endif
252 
253   // Main worker thread loop.
WorkerLoop(int thread_id)254   void WorkerLoop(int thread_id) {
255 #ifndef EIGEN_THREAD_LOCAL
256     std::unique_ptr<PerThread> new_pt(new PerThread());
257     per_thread_map_mutex_.lock();
258     bool insertOK = per_thread_map_.emplace(GlobalThreadIdHash(), std::move(new_pt)).second;
259     eigen_plain_assert(insertOK);
260     EIGEN_UNUSED_VARIABLE(insertOK);
261     per_thread_map_mutex_.unlock();
262     init_barrier_->Notify();
263     init_barrier_->Wait();
264 #endif
265     PerThread* pt = GetPerThread();
266     pt->pool = this;
267     pt->rand = GlobalThreadIdHash();
268     pt->thread_id = thread_id;
269     Queue& q = thread_data_[thread_id].queue;
270     EventCount::Waiter* waiter = &waiters_[thread_id];
271     // TODO(dvyukov,rmlarsen): The time spent in NonEmptyQueueIndex() is
272     // proportional to num_threads_ and we assume that new work is scheduled at
273     // a constant rate, so we set spin_count to 5000 / num_threads_. The
274     // constant was picked based on a fair dice roll, tune it.
275     const int spin_count =
276         allow_spinning_ && num_threads_ > 0 ? 5000 / num_threads_ : 0;
277     if (num_threads_ == 1) {
278       // For num_threads_ == 1 there is no point in going through the expensive
279       // steal loop. Moreover, since NonEmptyQueueIndex() calls PopBack() on the
280       // victim queues it might reverse the order in which ops are executed
281       // compared to the order in which they are scheduled, which tends to be
282       // counter-productive for the types of I/O workloads the single thread
283       // pools tend to be used for.
284       while (!cancelled_) {
285         Task t = q.PopFront();
286         for (int i = 0; i < spin_count && !t.f; i++) {
287           if (!cancelled_.load(std::memory_order_relaxed)) {
288             t = q.PopFront();
289           }
290         }
291         if (!t.f) {
292           if (!WaitForWork(waiter, &t)) {
293             return;
294           }
295         }
296         if (t.f) {
297           env_.ExecuteTask(t);
298         }
299       }
300     } else {
301       while (!cancelled_) {
302         Task t = q.PopFront();
303         if (!t.f) {
304           t = LocalSteal();
305           if (!t.f) {
306             t = GlobalSteal();
307             if (!t.f) {
308               // Leave one thread spinning. This reduces latency.
309               if (allow_spinning_ && !spinning_ && !spinning_.exchange(true)) {
310                 for (int i = 0; i < spin_count && !t.f; i++) {
311                   if (!cancelled_.load(std::memory_order_relaxed)) {
312                     t = GlobalSteal();
313                   } else {
314                     return;
315                   }
316                 }
317                 spinning_ = false;
318               }
319               if (!t.f) {
320                 if (!WaitForWork(waiter, &t)) {
321                   return;
322                 }
323               }
324             }
325           }
326         }
327         if (t.f) {
328           env_.ExecuteTask(t);
329         }
330       }
331     }
332   }
333 
334   // Steal tries to steal work from other worker threads in the range [start,
335   // limit) in best-effort manner.
Steal(unsigned start,unsigned limit)336   Task Steal(unsigned start, unsigned limit) {
337     PerThread* pt = GetPerThread();
338     const size_t size = limit - start;
339     unsigned r = Rand(&pt->rand);
340     // Reduce r into [0, size) range, this utilizes trick from
341     // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
342     eigen_plain_assert(all_coprimes_[size - 1].size() < (1<<30));
343     unsigned victim = ((uint64_t)r * (uint64_t)size) >> 32;
344     unsigned index = ((uint64_t) all_coprimes_[size - 1].size() * (uint64_t)r) >> 32;
345     unsigned inc = all_coprimes_[size - 1][index];
346 
347     for (unsigned i = 0; i < size; i++) {
348       eigen_plain_assert(start + victim < limit);
349       Task t = thread_data_[start + victim].queue.PopBack();
350       if (t.f) {
351         return t;
352       }
353       victim += inc;
354       if (victim >= size) {
355         victim -= size;
356       }
357     }
358     return Task();
359   }
360 
361   // Steals work within threads belonging to the partition.
LocalSteal()362   Task LocalSteal() {
363     PerThread* pt = GetPerThread();
364     unsigned partition = GetStealPartition(pt->thread_id);
365     // If thread steal partition is the same as global partition, there is no
366     // need to go through the steal loop twice.
367     if (global_steal_partition_ == partition) return Task();
368     unsigned start, limit;
369     DecodePartition(partition, &start, &limit);
370     AssertBounds(start, limit);
371 
372     return Steal(start, limit);
373   }
374 
375   // Steals work from any other thread in the pool.
GlobalSteal()376   Task GlobalSteal() {
377     return Steal(0, num_threads_);
378   }
379 
380 
381   // WaitForWork blocks until new work is available (returns true), or if it is
382   // time to exit (returns false). Can optionally return a task to execute in t
383   // (in such case t.f != nullptr on return).
WaitForWork(EventCount::Waiter * waiter,Task * t)384   bool WaitForWork(EventCount::Waiter* waiter, Task* t) {
385     eigen_plain_assert(!t->f);
386     // We already did best-effort emptiness check in Steal, so prepare for
387     // blocking.
388     ec_.Prewait();
389     // Now do a reliable emptiness check.
390     int victim = NonEmptyQueueIndex();
391     if (victim != -1) {
392       ec_.CancelWait();
393       if (cancelled_) {
394         return false;
395       } else {
396         *t = thread_data_[victim].queue.PopBack();
397         return true;
398       }
399     }
400     // Number of blocked threads is used as termination condition.
401     // If we are shutting down and all worker threads blocked without work,
402     // that's we are done.
403     blocked_++;
404     // TODO is blocked_ required to be unsigned?
405     if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
406       ec_.CancelWait();
407       // Almost done, but need to re-check queues.
408       // Consider that all queues are empty and all worker threads are preempted
409       // right after incrementing blocked_ above. Now a free-standing thread
410       // submits work and calls destructor (which sets done_). If we don't
411       // re-check queues, we will exit leaving the work unexecuted.
412       if (NonEmptyQueueIndex() != -1) {
413         // Note: we must not pop from queues before we decrement blocked_,
414         // otherwise the following scenario is possible. Consider that instead
415         // of checking for emptiness we popped the only element from queues.
416         // Now other worker threads can start exiting, which is bad if the
417         // work item submits other work. So we just check emptiness here,
418         // which ensures that all worker threads exit at the same time.
419         blocked_--;
420         return true;
421       }
422       // Reached stable termination state.
423       ec_.Notify(true);
424       return false;
425     }
426     ec_.CommitWait(waiter);
427     blocked_--;
428     return true;
429   }
430 
NonEmptyQueueIndex()431   int NonEmptyQueueIndex() {
432     PerThread* pt = GetPerThread();
433     // We intentionally design NonEmptyQueueIndex to steal work from
434     // anywhere in the queue so threads don't block in WaitForWork() forever
435     // when all threads in their partition go to sleep. Steal is still local.
436     const size_t size = thread_data_.size();
437     unsigned r = Rand(&pt->rand);
438     unsigned inc = all_coprimes_[size - 1][r % all_coprimes_[size - 1].size()];
439     unsigned victim = r % size;
440     for (unsigned i = 0; i < size; i++) {
441       if (!thread_data_[victim].queue.Empty()) {
442         return victim;
443       }
444       victim += inc;
445       if (victim >= size) {
446         victim -= size;
447       }
448     }
449     return -1;
450   }
451 
GlobalThreadIdHash()452   static EIGEN_STRONG_INLINE uint64_t GlobalThreadIdHash() {
453     return std::hash<std::thread::id>()(std::this_thread::get_id());
454   }
455 
GetPerThread()456   EIGEN_STRONG_INLINE PerThread* GetPerThread() {
457 #ifndef EIGEN_THREAD_LOCAL
458     static PerThread dummy;
459     auto it = per_thread_map_.find(GlobalThreadIdHash());
460     if (it == per_thread_map_.end()) {
461       return &dummy;
462     } else {
463       return it->second.get();
464     }
465 #else
466     EIGEN_THREAD_LOCAL PerThread per_thread_;
467     PerThread* pt = &per_thread_;
468     return pt;
469 #endif
470   }
471 
Rand(uint64_t * state)472   static EIGEN_STRONG_INLINE unsigned Rand(uint64_t* state) {
473     uint64_t current = *state;
474     // Update the internal state
475     *state = current * 6364136223846793005ULL + 0xda3e39cb94b95bdbULL;
476     // Generate the random output (using the PCG-XSH-RS scheme)
477     return static_cast<unsigned>((current ^ (current >> 22)) >>
478                                  (22 + (current >> 61)));
479   }
480 };
481 
482 typedef ThreadPoolTempl<StlThreadEnvironment> ThreadPool;
483 
484 }  // namespace Eigen
485 
486 #endif  // EIGEN_CXX11_THREADPOOL_NONBLOCKING_THREAD_POOL_H
487