xref: /aosp_15_r20/external/ruy/ruy/thread_pool.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1 /* Copyright 2019 Google LLC. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "ruy/thread_pool.h"
17 
18 #include <atomic>
19 #include <chrono>              // NOLINT(build/c++11)
20 #include <condition_variable>  // NOLINT(build/c++11)
21 #include <cstdint>
22 #include <cstdlib>
23 #include <memory>
24 #include <mutex>   // NOLINT(build/c++11)
25 #include <thread>  // NOLINT(build/c++11)
26 
27 #include "ruy/check_macros.h"
28 #include "ruy/denormal.h"
29 #include "ruy/trace.h"
30 #include "ruy/wait.h"
31 
32 namespace ruy {
33 
34 // A worker thread.
35 class Thread {
36  public:
Thread(BlockingCounter * count_busy_threads,Duration spin_duration)37   explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration)
38       : state_(State::Startup),
39         count_busy_threads_(count_busy_threads),
40         spin_duration_(spin_duration) {
41     thread_.reset(new std::thread(ThreadFunc, this));
42   }
43 
RequestExitAsSoonAsPossible()44   void RequestExitAsSoonAsPossible() {
45     ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible);
46   }
47 
~Thread()48   ~Thread() {
49     RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible);
50     thread_->join();
51   }
52 
53   // Called by an outside thead to give work to the worker thread.
StartWork(Task * task)54   void StartWork(Task* task) {
55     ChangeStateFromOutsideThread(State::HasWork, task);
56   }
57 
58  private:
59   enum class State {
60     Startup,  // The initial state before the thread loop runs.
61     Ready,    // Is not working, has not yet received new work to do.
62     HasWork,  // Has work to do.
63     ExitAsSoonAsPossible  // Should exit at earliest convenience.
64   };
65 
66   // Implements the state_ change to State::Ready, which is where we consume
67   // task_. Only called on the worker thread.
68   // Reads task_, so assumes ordering past any prior writes to task_.
RevertToReadyState()69   void RevertToReadyState() {
70     RUY_TRACE_SCOPE_NAME("Worker thread task");
71     // See task_ member comment for the ordering of accesses.
72     if (task_) {
73       task_->Run();
74       task_ = nullptr;
75     }
76     // No need to notify state_cond_, since only the worker thread ever waits
77     // on it, and we are that thread.
78     // Relaxed order because ordering is already provided by the
79     // count_busy_threads_->DecrementCount() at the next line, since the next
80     // state_ mutation will be to give new work and that won't happen before
81     // the outside thread has finished the current batch with a
82     // count_busy_threads_->Wait().
83     state_.store(State::Ready, std::memory_order_relaxed);
84     count_busy_threads_->DecrementCount();
85   }
86 
87   // Changes State, from outside thread.
88   //
89   // The Task argument is to be used only with new_state==HasWork.
90   // It specifies the Task being handed to this Thread.
91   //
92   // new_task is only used with State::HasWork.
ChangeStateFromOutsideThread(State new_state,Task * new_task=nullptr)93   void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) {
94     RUY_DCHECK(new_state == State::ExitAsSoonAsPossible ||
95                new_state == State::HasWork);
96     RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork));
97 
98 #ifndef NDEBUG
99     // Debug-only sanity checks based on old_state.
100     State old_state = state_.load();
101     RUY_DCHECK_NE(old_state, new_state);
102     RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork);
103     RUY_DCHECK_NE(old_state, new_state);
104 #endif
105 
106     switch (new_state) {
107       case State::HasWork:
108         // See task_ member comment for the ordering of accesses.
109         RUY_DCHECK(!task_);
110         task_ = new_task;
111         break;
112       case State::ExitAsSoonAsPossible:
113         break;
114       default:
115         abort();
116     }
117     // Release order because the worker thread will read this with acquire
118     // order.
119     state_.store(new_state, std::memory_order_release);
120     state_cond_mutex_.lock();
121     state_cond_.notify_one();  // Only this one worker thread cares.
122     state_cond_mutex_.unlock();
123   }
124 
ThreadFunc(Thread * arg)125   static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
126 
127   // Waits for state_ to be different from State::Ready, and returns that
128   // new value.
GetNewStateOtherThanReady()129   State GetNewStateOtherThanReady() {
130     State new_state;
131     const auto& new_state_not_ready = [this, &new_state]() {
132       new_state = state_.load(std::memory_order_acquire);
133       return new_state != State::Ready;
134     };
135     RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
136     Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_);
137     return new_state;
138   }
139 
140   // Thread entry point.
ThreadFuncImpl()141   void ThreadFuncImpl() {
142     RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
143     RevertToReadyState();
144 
145     // Suppress denormals to avoid computation inefficiency.
146     ScopedSuppressDenormals suppress_denormals;
147 
148     // Thread loop
149     while (GetNewStateOtherThanReady() == State::HasWork) {
150       RevertToReadyState();
151     }
152 
153     // Thread end. We should only get here if we were told to exit.
154     RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible);
155   }
156 
157   // The underlying thread. Used to join on destruction.
158   std::unique_ptr<std::thread> thread_;
159 
160   // The task to be worked on.
161   //
162   // The ordering of reads and writes to task_ is as follows.
163   //
164   // 1. The outside thread gives new work by calling
165   //      ChangeStateFromOutsideThread(State::HasWork, new_task);
166   //    That does:
167   //    - a. Write task_ = new_task (non-atomic).
168   //    - b. Store state_ = State::HasWork (memory_order_release).
169   // 2. The worker thread picks up the new state by calling
170   //      GetNewStateOtherThanReady()
171   //    That does:
172   //    - c. Load state (memory_order_acquire).
173   //    The worker thread then reads the new task in RevertToReadyState().
174   //    That does:
175   //    - d. Read task_ (non-atomic).
176   // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and
177   //    does:
178   //    - e. Write task_ = nullptr (non-atomic).
179   //    And then calls Call count_busy_threads_->DecrementCount()
180   //    which does
181   //    - f. Store count_busy_threads_ (memory_order_release).
182   // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker
183   //    threads by calling count_busy_threads_->Wait(), which does:
184   //    - g. Load count_busy_threads_ (memory_order_acquire).
185   //
186   // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are
187   // ordered by the release-acquire relationship of accesses to state_ (b. ->
188   // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by
189   // the release-acquire relationship of accesses to count_busy_threads_ (f. ->
190   // g.).
191   Task* task_ = nullptr;
192 
193   // Condition variable used by the outside thread to notify the worker thread
194   // of a state change.
195   std::condition_variable state_cond_;
196 
197   // Mutex used to guard state_cond_
198   std::mutex state_cond_mutex_;
199 
200   // The state enum tells if we're currently working, waiting for work, etc.
201   // It is written to from either the outside thread or the worker thread,
202   // in the ChangeState method.
203   // It is only ever read by the worker thread.
204   std::atomic<State> state_;
205 
206   // pointer to the master's thread BlockingCounter object, to notify the
207   // master thread of when this thread switches to the 'Ready' state.
208   BlockingCounter* const count_busy_threads_;
209 
210   // See ThreadPool::spin_duration_.
211   const Duration spin_duration_;
212 };
213 
ExecuteImpl(int task_count,int stride,Task * tasks)214 void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
215   RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
216   RUY_DCHECK_GE(task_count, 1);
217 
218   // Case of 1 thread: just run the single task on the current thread.
219   if (task_count == 1) {
220     (tasks + 0)->Run();
221     return;
222   }
223 
224   // Task #0 will be run on the current thread.
225   CreateThreads(task_count - 1);
226   count_busy_threads_.Reset(task_count - 1);
227   for (int i = 1; i < task_count; i++) {
228     RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
229     auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
230     threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
231   }
232 
233   RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
234   // Execute task #0 immediately on the current thread.
235   (tasks + 0)->Run();
236 
237   RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
238   // Wait for the threads submitted above to finish.
239   count_busy_threads_.Wait(spin_duration_);
240 }
241 
242 // Ensures that the pool has at least the given count of threads.
243 // If any new thread has to be created, this function waits for it to
244 // be ready.
CreateThreads(int threads_count)245 void ThreadPool::CreateThreads(int threads_count) {
246   RUY_DCHECK_GE(threads_count, 0);
247   unsigned int unsigned_threads_count = threads_count;
248   if (threads_.size() >= unsigned_threads_count) {
249     return;
250   }
251   count_busy_threads_.Reset(threads_count - threads_.size());
252   while (threads_.size() < unsigned_threads_count) {
253     threads_.push_back(new Thread(&count_busy_threads_, spin_duration_));
254   }
255   count_busy_threads_.Wait(spin_duration_);
256 }
257 
~ThreadPool()258 ThreadPool::~ThreadPool() {
259   // Send all exit requests upfront so threads can work on them in parallel.
260   for (auto w : threads_) {
261     w->RequestExitAsSoonAsPossible();
262   }
263   for (auto w : threads_) {
264     delete w;
265   }
266 }
267 
268 }  // end namespace ruy
269