// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Multi-thread worker.

#![deny(missing_docs)]

use std::collections::VecDeque;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use anyhow::Context;
use base::error;
use base::Event;
use base::EventWaitResult;
use sync::Condvar;
use sync::Mutex;

/// Task to run on the worker threads.
pub trait Task {
    /// Executes the task.
    fn execute(self);
}

/// Multi-thread based worker executing a single type of [Task].
///
/// See the documentation of [Channel] as well for its behavior.
pub struct Worker<T> {
    /// Shared [Channel] with the worker threads.
    pub channel: Arc<Channel<T>>,
    handles: Vec<thread::JoinHandle<()>>,
}

impl<T: Task + Send + 'static> Worker<T> {
    /// Spawns the given number of worker threads.
    pub fn new(len_channel: usize, n_workers: usize) -> Self {
        let channel = Arc::new(Channel::<T>::new(len_channel, n_workers));
        let mut handles = Vec::with_capacity(n_workers);
        for _ in 0..n_workers {
            let context = channel.clone();
            let handle = thread::spawn(move || {
                Self::worker_thread(context);
            });
            handles.push(handle);
        }
        Self { channel, handles }
    }

    fn worker_thread(context: Arc<Channel<T>>) {
        while let Some(task) = context.pop() {
            task.execute();
        }
    }

    /// Closes the channel and waits for the worker threads to shut down.
    ///
    /// This also waits for all the tasks in the channel to be executed.
    pub fn close(self) {
        self.channel.close();
        for handle in self.handles {
            match handle.join() {
                Ok(()) => {}
                Err(e) => {
                    error!("failed to wait for worker thread: {:?}", e);
                }
            }
        }
    }
}
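
// Usage sketch (illustrative, not part of the original file): shows how a caller might
// implement [Task], spawn a [Worker], push tasks, and shut it down. `CountTask` and
// `worker_usage_sketch` are hypothetical names introduced only for this example.
#[allow(dead_code)]
mod worker_usage_sketch {
    use std::sync::atomic::AtomicUsize;

    use super::*;

    struct CountTask(Arc<AtomicUsize>);

    impl Task for CountTask {
        fn execute(self) {
            // Each executed task bumps the shared counter.
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    fn run() {
        let counter = Arc::new(AtomicUsize::new(0));
        // A channel with 8 slots drained by 2 worker threads.
        let worker = Worker::new(8, 2);
        for _ in 0..4 {
            // `push()` returns true while the channel is open.
            assert!(worker.channel.push(CountTask(counter.clone())));
        }
        // Blocks until all queued tasks have finished executing.
        worker.channel.wait_complete();
        assert_eq!(counter.load(Ordering::SeqCst), 4);
        // Closes the channel and joins the worker threads.
        worker.close();
    }
}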

/// MPMC (multiple-producer, multiple-consumer) queue integrated with [Worker].
///
/// [Channel] offers [Channel::wait_complete()] to guarantee that all the tasks are executed.
///
/// This only exposes methods for producers.
pub struct Channel<T> {
    state: Mutex<ChannelState<T>>,
    consumer_wait: Condvar,
    producer_wait: Condvar,
    n_consumers: usize,
}

impl<T> Channel<T> {
    fn new(len: usize, n_consumers: usize) -> Self {
        Self {
            state: Mutex::new(ChannelState::new(len)),
            consumer_wait: Condvar::new(),
            producer_wait: Condvar::new(),
            n_consumers,
        }
    }

    fn close(&self) {
        let mut state = self.state.lock();
        state.is_closed = true;
        self.consumer_wait.notify_all();
        self.producer_wait.notify_all();
    }

    /// Pops a task from the channel.
    ///
    /// If the queue is closed **and** empty, this returns [None]. Tasks remaining in the
    /// queue are still returned even after the channel is closed.
    #[inline]
    fn pop(&self) -> Option<T> {
        let mut state = self.state.lock();
        loop {
            let was_full = state.queue.len() == state.capacity;
            if let Some(item) = state.queue.pop_front() {
                if was_full {
                    // Notify a producer waiting in `push()`.
                    self.producer_wait.notify_one();
                }
                return Some(item);
            } else {
                if state.is_closed {
                    return None;
                }
                state.n_waiting += 1;
                if state.n_waiting == self.n_consumers {
                    // Notify producers waiting in `wait_complete()`.
                    self.producer_wait.notify_all();
                }
                state = self.consumer_wait.wait(state);
                state.n_waiting -= 1;
            }
        }
    }

    /// Pushes a task.
    ///
    /// This blocks if the channel is full.
    ///
    /// If the channel is closed, this returns `false`.
    pub fn push(&self, item: T) -> bool {
        let mut state = self.state.lock();
        // Wait until the queue has room to push a task.
        while state.queue.len() == state.capacity {
            if state.is_closed {
                return false;
            }
            state = self.producer_wait.wait(state);
        }
        if state.is_closed {
            return false;
        }
        state.queue.push_back(item);
        self.consumer_wait.notify_one();
        true
    }

    /// Waits until all the tasks have been executed.
    ///
    /// This guarantees that all the tasks in this channel are not only consumed but also executed.
    pub fn wait_complete(&self) {
        let mut state = self.state.lock();
        while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) {
            state = self.producer_wait.wait(state);
        }
    }
}
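
// Behavior sketch (illustrative, not part of the original file): shows, from inside this
// module, that `pop()` keeps handing out tasks that were queued before `close()` and only
// returns `None` once the closed queue is empty. `NoopTask` and `pop_after_close_sketch`
// are hypothetical names introduced only for this example.
#[allow(dead_code)]
fn pop_after_close_sketch() {
    struct NoopTask;
    impl Task for NoopTask {
        fn execute(self) {}
    }

    // Two-slot channel with a single consumer.
    let channel = Channel::<NoopTask>::new(2, 1);
    assert!(channel.push(NoopTask));
    channel.close();
    // A task queued before close() is still returned afterwards.
    assert!(channel.pop().is_some());
    // With the queue empty and closed, pop() returns None instead of blocking.
    assert!(channel.pop().is_none());
}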

struct ChannelState<T> {
    queue: VecDeque<T>,
    capacity: usize,
    n_waiting: usize,
    is_closed: bool,
}

impl<T> ChannelState<T> {
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            capacity,
            n_waiting: 0,
            is_closed: false,
        }
    }
}

/// The event channel for background jobs.
///
/// This sends an abort request from the main thread to the job thread via an atomic boolean
/// flag.
///
/// This notifies the main thread that the job thread has completed via [Event].
pub struct BackgroundJobControl {
    event: Event,
    abort_flag: AtomicBool,
}

impl BackgroundJobControl {
    /// Creates [BackgroundJobControl].
    pub fn new() -> anyhow::Result<Self> {
        Ok(Self {
            event: Event::new()?,
            abort_flag: AtomicBool::new(false),
        })
    }

    /// Creates [BackgroundJob].
    pub fn new_job(&self) -> BackgroundJob<'_> {
        BackgroundJob {
            event: &self.event,
            abort_flag: &self.abort_flag,
        }
    }

    /// Aborts the background job.
    pub fn abort(&self) {
        self.abort_flag.store(true, Ordering::Release);
    }

    /// Resets the internal state for the next job.
    ///
    /// Returns `false` if the event has already been reset and no completion event is pending.
    pub fn reset(&self) -> anyhow::Result<bool> {
        self.abort_flag.store(false, Ordering::Release);
        Ok(matches!(
            self.event
                .wait_timeout(Duration::ZERO)
                .context("failed to get job complete event")?,
            EventWaitResult::Signaled
        ))
    }

    /// Returns the event that notifies completion of the background job.
    pub fn get_completion_event(&self) -> &Event {
        &self.event
    }
}

/// Background job context.
///
/// When dropped, this sends an event to the main thread via [Event].
pub struct BackgroundJob<'a> {
    event: &'a Event,
    abort_flag: &'a AtomicBool,
}

impl BackgroundJob<'_> {
    /// Returns whether the background job is aborted or not.
    pub fn is_aborted(&self) -> bool {
        self.abort_flag.load(Ordering::Acquire)
    }
}

impl Drop for BackgroundJob<'_> {
    fn drop(&mut self) {
        self.event.signal().expect("send job complete event");
    }
}
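
// Usage sketch (illustrative, not part of the original file): one way the main thread and a
// job thread might cooperate through [BackgroundJobControl]. The job polls `is_aborted()`,
// and dropping the [BackgroundJob] signals the completion event; `reset()` then consumes
// that event before the next job. `background_job_sketch` and the loop body are hypothetical.
#[allow(dead_code)]
fn background_job_sketch() -> anyhow::Result<()> {
    let control = Arc::new(BackgroundJobControl::new()?);

    let control_for_job = control.clone();
    let handle = thread::spawn(move || {
        let job = control_for_job.new_job();
        for _ in 0..1000 {
            if job.is_aborted() {
                break;
            }
            // ... one unit of background work per iteration ...
        }
        // Dropping `job` here signals the completion event to the main thread.
    });

    // Ask the job to stop early and wait for the thread to finish.
    control.abort();
    handle.join().expect("job thread panicked");

    // `reset()` clears the abort flag and consumes the completion event; it returns true
    // because the dropped job signaled the event.
    assert!(control.reset()?);
    Ok(())
}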

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[derive(Clone, Copy)]
    struct Context {
        n_consume: usize,
        n_executed: usize,
    }

    struct FakeTask {
        context: Mutex<Context>,
        waker: Condvar,
    }

    impl FakeTask {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                context: Mutex::new(Context {
                    n_consume: 0,
                    n_executed: 0,
                }),
                waker: Condvar::new(),
            })
        }

        fn consume(&self, count: usize) {
            let mut context = self.context.lock();
            context.n_consume += count;
            self.waker.notify_all();
        }

        fn n_executed(&self) -> usize {
            self.context.lock().n_executed
        }
    }

    impl Task for Arc<FakeTask> {
        fn execute(self) {
            let mut context = self.context.lock();
            while context.n_consume == 0 {
                context = self.waker.wait(context);
            }
            context.n_consume -= 1;
            context.n_executed += 1;
        }
    }

    fn wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T {
        for _ in 0..timeout_millis {
            if join_handle.is_finished() {
                return join_handle.join().unwrap();
            }
            thread::sleep(Duration::from_millis(1));
        }
        panic!("thread join timeout");
    }

    fn poll_until_with_timeout<F>(f: F, timeout_millis: u64)
    where
        F: Fn() -> bool,
    {
        for _ in 0..timeout_millis {
            if f() {
                break;
            }
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_worker() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        for _ in 0..4 {
            assert!(channel.push(task.clone()));
        }

        assert_eq!(task.n_executed(), 0);
        task.consume(4);
        worker.channel.wait_complete();
        assert_eq!(task.n_executed(), 4);
        worker.close();
    }

    #[test]
    fn test_worker_push_after_close() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        worker.close();

        assert!(!channel.push(task));
    }

    #[test]
    fn test_worker_push_block() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        let task_cloned = task.clone();
        // push tasks on another thread to avoid blocking forever
        wait_thread_with_timeout(
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task_cloned.clone()));
                }
            }),
            100,
        );
        let channel = worker.channel.clone();
        let task_cloned = task.clone();
        let push_thread = thread::spawn(move || {
            assert!(channel.push(task_cloned));
        });
        thread::sleep(Duration::from_millis(10));
        assert!(!push_thread.is_finished());

        task.consume(1);
        wait_thread_with_timeout(push_thread, 100);

        task.consume(6);
        #[allow(clippy::redundant_clone)]
        let task_clone = task.clone();
        poll_until_with_timeout(|| task_clone.n_executed() == 7, 100);
        assert_eq!(task.n_executed(), 7);
        worker.close();
    }

    #[test]
    fn test_worker_close_on_push_blocked() {
        let worker = Worker::new(2, 4);
        let task = FakeTask::new();
        let channel = worker.channel.clone();

        let task_cloned = task.clone();
        // push tasks on another thread to avoid blocking forever
        wait_thread_with_timeout(
            thread::spawn(move || {
                for _ in 0..6 {
                    assert!(channel.push(task_cloned.clone()));
                }
            }),
            100,
        );
        let channel = worker.channel.clone();
        let task_cloned = task.clone();
        let push_thread = thread::spawn(move || channel.push(task_cloned));
        // sleep to let push_thread run.
        thread::sleep(Duration::from_millis(10));
        // close() blocks until all the tasks are executed.
        let close_thread = thread::spawn(move || {
            worker.close();
        });
        let push_result = wait_thread_with_timeout(push_thread, 100);
        // push fails.
        assert!(!push_result);

        // cleanup
        task.consume(6);
        wait_thread_with_timeout(close_thread, 100);
    }

    #[test]
    fn new_background_job_event() {
        assert!(BackgroundJobControl::new().is_ok());
    }

    #[test]
    fn background_job_is_not_aborted_default() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();

        assert!(!job.is_aborted());
    }

    #[test]
    fn abort_background_job() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        event.abort();

        assert!(job.is_aborted());
    }

    #[test]
    fn reset_background_job() {
        let event = BackgroundJobControl::new().unwrap();

        event.abort();
        event.reset().unwrap();
        let job = event.new_job();

        assert!(!job.is_aborted());
    }

    #[test]
    fn reset_background_job_event() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        drop(job);

        assert!(event.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_twice() {
        let event = BackgroundJobControl::new().unwrap();

        let job = event.new_job();
        drop(job);

        event.reset().unwrap();
        assert!(!event.reset().unwrap());
    }

    #[test]
    fn reset_background_job_event_no_jobs() {
        let event = BackgroundJobControl::new().unwrap();

        assert!(!event.reset().unwrap());
    }
}