1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 //! Multi-thread worker. 6 7 #![deny(missing_docs)] 8 9 use std::collections::VecDeque; 10 use std::sync::atomic::AtomicBool; 11 use std::sync::atomic::Ordering; 12 use std::sync::Arc; 13 use std::thread; 14 use std::time::Duration; 15 16 use anyhow::Context; 17 use base::error; 18 use base::Event; 19 use base::EventWaitResult; 20 use sync::Condvar; 21 use sync::Mutex; 22 23 /// Task to run on the worker threads. 24 pub trait Task { 25 /// Executes the task. execute(self)26 fn execute(self); 27 } 28 29 /// Multi thread based worker executing a single type [Task]. 30 /// 31 /// See the doc of [Channel] as well for the behaviors of it. 32 pub struct Worker<T> { 33 /// Shared [Channel] with the worker threads. 34 pub channel: Arc<Channel<T>>, 35 handles: Vec<thread::JoinHandle<()>>, 36 } 37 38 impl<T: Task + Send + 'static> Worker<T> { 39 /// Spawns the numbers of worker threads. new(len_channel: usize, n_workers: usize) -> Self40 pub fn new(len_channel: usize, n_workers: usize) -> Self { 41 let channel = Arc::new(Channel::<T>::new(len_channel, n_workers)); 42 let mut handles = Vec::with_capacity(n_workers); 43 for _ in 0..n_workers { 44 let context = channel.clone(); 45 let handle = thread::spawn(move || { 46 Self::worker_thread(context); 47 }); 48 handles.push(handle); 49 } 50 Self { channel, handles } 51 } 52 worker_thread(context: Arc<Channel<T>>)53 fn worker_thread(context: Arc<Channel<T>>) { 54 while let Some(task) = context.pop() { 55 task.execute(); 56 } 57 } 58 59 /// Closes the channel and wait for worker threads shutdown. 60 /// 61 /// This also waits for all the tasks in the channel to be executed. close(self)62 pub fn close(self) { 63 self.channel.close(); 64 for handle in self.handles { 65 match handle.join() { 66 Ok(()) => {} 67 Err(e) => { 68 error!("failed to wait for worker thread: {:?}", e); 69 } 70 } 71 } 72 } 73 } 74 75 /// MPMC (Multi Producers Multi Consumers) queue integrated with [Worker]. 76 /// 77 /// [Channel] offers [Channel::wait_complete()] to guarantee all the tasks are executed. 78 /// 79 /// This only exposes methods for producers. 80 pub struct Channel<T> { 81 state: Mutex<ChannelState<T>>, 82 consumer_wait: Condvar, 83 producer_wait: Condvar, 84 n_consumers: usize, 85 } 86 87 impl<T> Channel<T> { new(len: usize, n_consumers: usize) -> Self88 fn new(len: usize, n_consumers: usize) -> Self { 89 Self { 90 state: Mutex::new(ChannelState::new(len)), 91 consumer_wait: Condvar::new(), 92 producer_wait: Condvar::new(), 93 n_consumers, 94 } 95 } 96 close(&self)97 fn close(&self) { 98 let mut state = self.state.lock(); 99 state.is_closed = true; 100 self.consumer_wait.notify_all(); 101 self.producer_wait.notify_all(); 102 } 103 104 /// Pops a task from the channel. 105 /// 106 /// If the queue is closed and also **empty**, this returns [None]. This returns all the tasks 107 /// in the queue even while this is closed. 108 #[inline] pop(&self) -> Option<T>109 fn pop(&self) -> Option<T> { 110 let mut state = self.state.lock(); 111 loop { 112 let was_full = state.queue.len() == state.capacity; 113 if let Some(item) = state.queue.pop_front() { 114 if was_full { 115 // notification for a producer waiting for `push()`. 116 self.producer_wait.notify_one(); 117 } 118 return Some(item); 119 } else { 120 if state.is_closed { 121 return None; 122 } 123 state.n_waiting += 1; 124 if state.n_waiting == self.n_consumers { 125 // notification for producers waiting for `wait_complete()`. 126 self.producer_wait.notify_all(); 127 } 128 state = self.consumer_wait.wait(state); 129 state.n_waiting -= 1; 130 } 131 } 132 } 133 134 /// Push a task. 135 /// 136 /// This blocks if the channel is full. 137 /// 138 /// If the channel is closed, this returns `false`. push(&self, item: T) -> bool139 pub fn push(&self, item: T) -> bool { 140 let mut state = self.state.lock(); 141 // Wait until the queue has room to push a task. 142 while state.queue.len() == state.capacity { 143 if state.is_closed { 144 return false; 145 } 146 state = self.producer_wait.wait(state); 147 } 148 if state.is_closed { 149 return false; 150 } 151 state.queue.push_back(item); 152 self.consumer_wait.notify_one(); 153 true 154 } 155 156 /// Wait until all the tasks have been executed. 157 /// 158 /// This guarantees that all the tasks in this channel are not only consumed but also executed. wait_complete(&self)159 pub fn wait_complete(&self) { 160 let mut state = self.state.lock(); 161 while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) { 162 state = self.producer_wait.wait(state); 163 } 164 } 165 } 166 167 struct ChannelState<T> { 168 queue: VecDeque<T>, 169 capacity: usize, 170 n_waiting: usize, 171 is_closed: bool, 172 } 173 174 impl<T> ChannelState<T> { new(capacity: usize) -> Self175 fn new(capacity: usize) -> Self { 176 Self { 177 queue: VecDeque::with_capacity(capacity), 178 capacity, 179 n_waiting: 0, 180 is_closed: false, 181 } 182 } 183 } 184 185 /// The event channel for background jobs. 186 /// 187 /// This sends an abort request from the main thread to the job thread via atomic boolean flag. 188 /// 189 /// This notifies the main thread that the job thread is completed via [Event]. 190 pub struct BackgroundJobControl { 191 event: Event, 192 abort_flag: AtomicBool, 193 } 194 195 impl BackgroundJobControl { 196 /// Creates [BackgroundJobControl]. new() -> anyhow::Result<Self>197 pub fn new() -> anyhow::Result<Self> { 198 Ok(Self { 199 event: Event::new()?, 200 abort_flag: AtomicBool::new(false), 201 }) 202 } 203 204 /// Creates [BackgroundJob]. new_job(&self) -> BackgroundJob<'_>205 pub fn new_job(&self) -> BackgroundJob<'_> { 206 BackgroundJob { 207 event: &self.event, 208 abort_flag: &self.abort_flag, 209 } 210 } 211 212 /// Abort the background job. abort(&self)213 pub fn abort(&self) { 214 self.abort_flag.store(true, Ordering::Release); 215 } 216 217 /// Reset the internal state for a next job. 218 /// 219 /// Returns false, if the event is already reset and no event exists. reset(&self) -> anyhow::Result<bool>220 pub fn reset(&self) -> anyhow::Result<bool> { 221 self.abort_flag.store(false, Ordering::Release); 222 Ok(matches!( 223 self.event 224 .wait_timeout(Duration::ZERO) 225 .context("failed to get job complete event")?, 226 EventWaitResult::Signaled 227 )) 228 } 229 230 /// Returns the event to notify the completion of background job. get_completion_event(&self) -> &Event231 pub fn get_completion_event(&self) -> &Event { 232 &self.event 233 } 234 } 235 236 /// Background job context. 237 /// 238 /// When dropped, this sends an event to the main thread via [Event]. 239 pub struct BackgroundJob<'a> { 240 event: &'a Event, 241 abort_flag: &'a AtomicBool, 242 } 243 244 impl BackgroundJob<'_> { 245 /// Returns whether the background job is aborted or not. is_aborted(&self) -> bool246 pub fn is_aborted(&self) -> bool { 247 self.abort_flag.load(Ordering::Acquire) 248 } 249 } 250 251 impl Drop for BackgroundJob<'_> { drop(&mut self)252 fn drop(&mut self) { 253 self.event.signal().expect("send job complete event"); 254 } 255 } 256 257 #[cfg(test)] 258 mod tests { 259 use std::time::Duration; 260 261 use super::*; 262 263 #[derive(Clone, Copy)] 264 struct Context { 265 n_consume: usize, 266 n_executed: usize, 267 } 268 269 struct FakeTask { 270 context: Mutex<Context>, 271 waker: Condvar, 272 } 273 274 impl FakeTask { new() -> Arc<Self>275 fn new() -> Arc<Self> { 276 Arc::new(Self { 277 context: Mutex::new(Context { 278 n_consume: 0, 279 n_executed: 0, 280 }), 281 waker: Condvar::new(), 282 }) 283 } 284 consume(&self, count: usize)285 fn consume(&self, count: usize) { 286 let mut context = self.context.lock(); 287 context.n_consume += count; 288 self.waker.notify_all(); 289 } 290 n_executed(&self) -> usize291 fn n_executed(&self) -> usize { 292 self.context.lock().n_executed 293 } 294 } 295 296 impl Task for Arc<FakeTask> { execute(self)297 fn execute(self) { 298 let mut context = self.context.lock(); 299 while context.n_consume == 0 { 300 context = self.waker.wait(context); 301 } 302 context.n_consume -= 1; 303 context.n_executed += 1; 304 } 305 } 306 wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T307 fn wait_thread_with_timeout<T>(join_handle: thread::JoinHandle<T>, timeout_millis: u64) -> T { 308 for _ in 0..timeout_millis { 309 if join_handle.is_finished() { 310 return join_handle.join().unwrap(); 311 } 312 thread::sleep(Duration::from_millis(1)); 313 } 314 panic!("thread join timeout"); 315 } 316 poll_until_with_timeout<F>(f: F, timeout_millis: u64) where F: Fn() -> bool,317 fn poll_until_with_timeout<F>(f: F, timeout_millis: u64) 318 where 319 F: Fn() -> bool, 320 { 321 for _ in 0..timeout_millis { 322 if f() { 323 break; 324 } 325 thread::sleep(Duration::from_millis(1)); 326 } 327 } 328 329 #[test] test_worker()330 fn test_worker() { 331 let worker = Worker::new(2, 4); 332 let task = FakeTask::new(); 333 let channel = worker.channel.clone(); 334 335 for _ in 0..4 { 336 assert!(channel.push(task.clone())); 337 } 338 339 assert_eq!(task.n_executed(), 0); 340 task.consume(4); 341 worker.channel.wait_complete(); 342 assert_eq!(task.n_executed(), 4); 343 worker.close(); 344 } 345 346 #[test] test_worker_push_after_close()347 fn test_worker_push_after_close() { 348 let worker = Worker::new(2, 4); 349 let task = FakeTask::new(); 350 let channel = worker.channel.clone(); 351 352 worker.close(); 353 354 assert!(!channel.push(task)); 355 } 356 357 #[test] test_worker_push_block()358 fn test_worker_push_block() { 359 let worker = Worker::new(2, 4); 360 let task = FakeTask::new(); 361 let channel = worker.channel.clone(); 362 363 let task_cloned = task.clone(); 364 // push tasks on another thread to avoid blocking forever 365 wait_thread_with_timeout( 366 thread::spawn(move || { 367 for _ in 0..6 { 368 assert!(channel.push(task_cloned.clone())); 369 } 370 }), 371 100, 372 ); 373 let channel = worker.channel.clone(); 374 let task_cloned = task.clone(); 375 let push_thread = thread::spawn(move || { 376 assert!(channel.push(task_cloned)); 377 }); 378 thread::sleep(Duration::from_millis(10)); 379 assert!(!push_thread.is_finished()); 380 381 task.consume(1); 382 wait_thread_with_timeout(push_thread, 100); 383 384 task.consume(6); 385 #[allow(clippy::redundant_clone)] 386 let task_clone = task.clone(); 387 poll_until_with_timeout(|| task_clone.n_executed() == 7, 100); 388 assert_eq!(task.n_executed(), 7); 389 worker.close(); 390 } 391 392 #[test] test_worker_close_on_push_blocked()393 fn test_worker_close_on_push_blocked() { 394 let worker = Worker::new(2, 4); 395 let task = FakeTask::new(); 396 let channel = worker.channel.clone(); 397 398 let task_cloned = task.clone(); 399 // push tasks on another thread to avoid blocking forever 400 wait_thread_with_timeout( 401 thread::spawn(move || { 402 for _ in 0..6 { 403 assert!(channel.push(task_cloned.clone())); 404 } 405 }), 406 100, 407 ); 408 let channel = worker.channel.clone(); 409 let task_cloned = task.clone(); 410 let push_thread = thread::spawn(move || channel.push(task_cloned)); 411 // sleep to run push_thread. 412 thread::sleep(Duration::from_millis(10)); 413 // close blocks until all the task are executed. 414 let close_thread = thread::spawn(move || { 415 worker.close(); 416 }); 417 let push_result = wait_thread_with_timeout(push_thread, 100); 418 // push fails. 419 assert!(!push_result); 420 421 // cleanup 422 task.consume(6); 423 wait_thread_with_timeout(close_thread, 100); 424 } 425 426 #[test] new_background_job_event()427 fn new_background_job_event() { 428 assert!(BackgroundJobControl::new().is_ok()); 429 } 430 431 #[test] background_job_is_not_aborted_default()432 fn background_job_is_not_aborted_default() { 433 let event = BackgroundJobControl::new().unwrap(); 434 435 let job = event.new_job(); 436 437 assert!(!job.is_aborted()); 438 } 439 440 #[test] abort_background_job()441 fn abort_background_job() { 442 let event = BackgroundJobControl::new().unwrap(); 443 444 let job = event.new_job(); 445 event.abort(); 446 447 assert!(job.is_aborted()); 448 } 449 450 #[test] reset_background_job()451 fn reset_background_job() { 452 let event = BackgroundJobControl::new().unwrap(); 453 454 event.abort(); 455 event.reset().unwrap(); 456 let job = event.new_job(); 457 458 assert!(!job.is_aborted()); 459 } 460 461 #[test] reset_background_job_event()462 fn reset_background_job_event() { 463 let event = BackgroundJobControl::new().unwrap(); 464 465 let job = event.new_job(); 466 drop(job); 467 468 assert!(event.reset().unwrap()); 469 } 470 471 #[test] reset_background_job_event_twice()472 fn reset_background_job_event_twice() { 473 let event = BackgroundJobControl::new().unwrap(); 474 475 let job = event.new_job(); 476 drop(job); 477 478 event.reset().unwrap(); 479 assert!(!event.reset().unwrap()); 480 } 481 482 #[test] reset_background_job_event_no_jobs()483 fn reset_background_job_event_no_jobs() { 484 let event = BackgroundJobControl::new().unwrap(); 485 486 assert!(!event.reset().unwrap()); 487 } 488 } 489