1 //! Waking mechanism for threads blocked on channel operations.
2 
3 use std::ptr;
4 use std::sync::atomic::{AtomicBool, Ordering};
5 use std::sync::Mutex;
6 use std::thread::{self, ThreadId};
7 use std::vec::Vec;
8 
9 use crate::context::Context;
10 use crate::select::{Operation, Selected};
11 
/// Represents a thread blocked on a specific channel operation.
///
/// Entries live in a [`Waker`]'s selector list (for select operations) or observer list (for
/// operations waiting to become ready).
pub(crate) struct Entry {
    /// The operation.
    ///
    /// Used to locate this entry when it is unregistered/unwatched, and reported to the owning
    /// thread as `Selected::Operation(oper)` when the entry is selected.
    pub(crate) oper: Operation,

    /// Optional packet.
    ///
    /// A null pointer means "no packet": `Waker::register` and `Waker::watch` both store
    /// `ptr::null_mut()`. When another thread selects this entry in `Waker::try_select`, the
    /// pointer is handed to the owning thread's context via `store_packet`.
    pub(crate) packet: *mut (),

    /// Context associated with the thread owning this operation.
    ///
    /// Used to compare thread ids, to attempt selection, and to unpark the owning thread.
    pub(crate) cx: Context,
}
23 
/// A queue of threads blocked on channel operations.
///
/// This data structure is used by threads to register blocking operations and get woken up once
/// an operation becomes ready.
///
/// `Waker` performs no synchronization of its own (all mutating methods take `&mut self`);
/// `SyncWaker` wraps it in a mutex for shared use.
pub(crate) struct Waker {
    /// A list of select operations.
    ///
    /// Kept in registration order: new entries are pushed to the back and removals preserve the
    /// order of the rest, so `try_select` always scans the oldest registrations first.
    selectors: Vec<Entry>,

    /// A list of operations waiting to be ready.
    ///
    /// The whole list is drained every time `notify` is called.
    observers: Vec<Entry>,
}
35 
36 impl Waker {
37     /// Creates a new `Waker`.
38     #[inline]
new() -> Self39     pub(crate) fn new() -> Self {
40         Waker {
41             selectors: Vec::new(),
42             observers: Vec::new(),
43         }
44     }
45 
46     /// Registers a select operation.
47     #[inline]
register(&mut self, oper: Operation, cx: &Context)48     pub(crate) fn register(&mut self, oper: Operation, cx: &Context) {
49         self.register_with_packet(oper, ptr::null_mut(), cx);
50     }
51 
52     /// Registers a select operation and a packet.
53     #[inline]
register_with_packet(&mut self, oper: Operation, packet: *mut (), cx: &Context)54     pub(crate) fn register_with_packet(&mut self, oper: Operation, packet: *mut (), cx: &Context) {
55         self.selectors.push(Entry {
56             oper,
57             packet,
58             cx: cx.clone(),
59         });
60     }
61 
62     /// Unregisters a select operation.
63     #[inline]
unregister(&mut self, oper: Operation) -> Option<Entry>64     pub(crate) fn unregister(&mut self, oper: Operation) -> Option<Entry> {
65         if let Some((i, _)) = self
66             .selectors
67             .iter()
68             .enumerate()
69             .find(|&(_, entry)| entry.oper == oper)
70         {
71             let entry = self.selectors.remove(i);
72             Some(entry)
73         } else {
74             None
75         }
76     }
77 
78     /// Attempts to find another thread's entry, select the operation, and wake it up.
79     #[inline]
try_select(&mut self) -> Option<Entry>80     pub(crate) fn try_select(&mut self) -> Option<Entry> {
81         if self.selectors.is_empty() {
82             None
83         } else {
84             let thread_id = current_thread_id();
85 
86             self.selectors
87                 .iter()
88                 .position(|selector| {
89                     // Does the entry belong to a different thread?
90                     selector.cx.thread_id() != thread_id
91                         && selector // Try selecting this operation.
92                             .cx
93                             .try_select(Selected::Operation(selector.oper))
94                             .is_ok()
95                         && {
96                             // Provide the packet.
97                             selector.cx.store_packet(selector.packet);
98                             // Wake the thread up.
99                             selector.cx.unpark();
100                             true
101                         }
102                 })
103                 // Remove the entry from the queue to keep it clean and improve
104                 // performance.
105                 .map(|pos| self.selectors.remove(pos))
106         }
107     }
108 
109     /// Returns `true` if there is an entry which can be selected by the current thread.
110     #[inline]
can_select(&self) -> bool111     pub(crate) fn can_select(&self) -> bool {
112         if self.selectors.is_empty() {
113             false
114         } else {
115             let thread_id = current_thread_id();
116 
117             self.selectors.iter().any(|entry| {
118                 entry.cx.thread_id() != thread_id && entry.cx.selected() == Selected::Waiting
119             })
120         }
121     }
122 
123     /// Registers an operation waiting to be ready.
124     #[inline]
watch(&mut self, oper: Operation, cx: &Context)125     pub(crate) fn watch(&mut self, oper: Operation, cx: &Context) {
126         self.observers.push(Entry {
127             oper,
128             packet: ptr::null_mut(),
129             cx: cx.clone(),
130         });
131     }
132 
133     /// Unregisters an operation waiting to be ready.
134     #[inline]
unwatch(&mut self, oper: Operation)135     pub(crate) fn unwatch(&mut self, oper: Operation) {
136         self.observers.retain(|e| e.oper != oper);
137     }
138 
139     /// Notifies all operations waiting to be ready.
140     #[inline]
notify(&mut self)141     pub(crate) fn notify(&mut self) {
142         for entry in self.observers.drain(..) {
143             if entry.cx.try_select(Selected::Operation(entry.oper)).is_ok() {
144                 entry.cx.unpark();
145             }
146         }
147     }
148 
149     /// Notifies all registered operations that the channel is disconnected.
150     #[inline]
disconnect(&mut self)151     pub(crate) fn disconnect(&mut self) {
152         for entry in self.selectors.iter() {
153             if entry.cx.try_select(Selected::Disconnected).is_ok() {
154                 // Wake the thread up.
155                 //
156                 // Here we don't remove the entry from the queue. Registered threads must
157                 // unregister from the waker by themselves. They might also want to recover the
158                 // packet value and destroy it, if necessary.
159                 entry.cx.unpark();
160             }
161         }
162 
163         self.notify();
164     }
165 }
166 
167 impl Drop for Waker {
168     #[inline]
drop(&mut self)169     fn drop(&mut self) {
170         debug_assert_eq!(self.selectors.len(), 0);
171         debug_assert_eq!(self.observers.len(), 0);
172     }
173 }
174 
175 /// A waker that can be shared among threads without locking.
176 ///
177 /// This is a simple wrapper around `Waker` that internally uses a mutex for synchronization.
178 pub(crate) struct SyncWaker {
179     /// The inner `Waker`.
180     inner: Mutex<Waker>,
181 
182     /// `true` if the waker is empty.
183     is_empty: AtomicBool,
184 }
185 
186 impl SyncWaker {
187     /// Creates a new `SyncWaker`.
188     #[inline]
new() -> Self189     pub(crate) fn new() -> Self {
190         SyncWaker {
191             inner: Mutex::new(Waker::new()),
192             is_empty: AtomicBool::new(true),
193         }
194     }
195 
196     /// Registers the current thread with an operation.
197     #[inline]
register(&self, oper: Operation, cx: &Context)198     pub(crate) fn register(&self, oper: Operation, cx: &Context) {
199         let mut inner = self.inner.lock().unwrap();
200         inner.register(oper, cx);
201         self.is_empty.store(
202             inner.selectors.is_empty() && inner.observers.is_empty(),
203             Ordering::SeqCst,
204         );
205     }
206 
207     /// Unregisters an operation previously registered by the current thread.
208     #[inline]
unregister(&self, oper: Operation) -> Option<Entry>209     pub(crate) fn unregister(&self, oper: Operation) -> Option<Entry> {
210         let mut inner = self.inner.lock().unwrap();
211         let entry = inner.unregister(oper);
212         self.is_empty.store(
213             inner.selectors.is_empty() && inner.observers.is_empty(),
214             Ordering::SeqCst,
215         );
216         entry
217     }
218 
219     /// Attempts to find one thread (not the current one), select its operation, and wake it up.
220     #[inline]
notify(&self)221     pub(crate) fn notify(&self) {
222         if !self.is_empty.load(Ordering::SeqCst) {
223             let mut inner = self.inner.lock().unwrap();
224             if !self.is_empty.load(Ordering::SeqCst) {
225                 inner.try_select();
226                 inner.notify();
227                 self.is_empty.store(
228                     inner.selectors.is_empty() && inner.observers.is_empty(),
229                     Ordering::SeqCst,
230                 );
231             }
232         }
233     }
234 
235     /// Registers an operation waiting to be ready.
236     #[inline]
watch(&self, oper: Operation, cx: &Context)237     pub(crate) fn watch(&self, oper: Operation, cx: &Context) {
238         let mut inner = self.inner.lock().unwrap();
239         inner.watch(oper, cx);
240         self.is_empty.store(
241             inner.selectors.is_empty() && inner.observers.is_empty(),
242             Ordering::SeqCst,
243         );
244     }
245 
246     /// Unregisters an operation waiting to be ready.
247     #[inline]
unwatch(&self, oper: Operation)248     pub(crate) fn unwatch(&self, oper: Operation) {
249         let mut inner = self.inner.lock().unwrap();
250         inner.unwatch(oper);
251         self.is_empty.store(
252             inner.selectors.is_empty() && inner.observers.is_empty(),
253             Ordering::SeqCst,
254         );
255     }
256 
257     /// Notifies all threads that the channel is disconnected.
258     #[inline]
disconnect(&self)259     pub(crate) fn disconnect(&self) {
260         let mut inner = self.inner.lock().unwrap();
261         inner.disconnect();
262         self.is_empty.store(
263             inner.selectors.is_empty() && inner.observers.is_empty(),
264             Ordering::SeqCst,
265         );
266     }
267 }
268 
269 impl Drop for SyncWaker {
270     #[inline]
drop(&mut self)271     fn drop(&mut self) {
272         debug_assert!(self.is_empty.load(Ordering::SeqCst));
273     }
274 }
275 
/// Returns the `ThreadId` of the calling thread.
///
/// The id is cached in a thread-local on first use. If that thread-local can no longer be
/// accessed (for example while thread-local destructors are running), the id is obtained
/// directly from `thread::current()` instead.
#[inline]
fn current_thread_id() -> ThreadId {
    std::thread_local! {
        /// Per-thread cache of `thread::current().id()`.
        static CACHED_ID: ThreadId = thread::current().id();
    }

    match CACHED_ID.try_with(|cached| *cached) {
        Ok(id) => id,
        Err(_) => thread::current().id(),
    }
}
288