xref: /aosp_15_r20/external/crosvm/base/src/sys/windows/wait.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp::min;
6 use std::collections::HashMap;
7 use std::os::windows::io::RawHandle;
8 use std::sync::Arc;
9 use std::time::Duration;
10 
11 use smallvec::SmallVec;
12 use sync::Mutex;
13 use winapi::shared::minwindef::DWORD;
14 use winapi::shared::minwindef::FALSE;
15 use winapi::shared::winerror::ERROR_INVALID_PARAMETER;
16 use winapi::shared::winerror::WAIT_TIMEOUT;
17 use winapi::um::synchapi::WaitForMultipleObjects;
18 use winapi::um::winbase::WAIT_OBJECT_0;
19 
20 use super::errno_result;
21 use super::Error;
22 use super::EventTrigger;
23 use super::Result;
24 use crate::descriptor::AsRawDescriptor;
25 use crate::descriptor::Descriptor;
26 use crate::error;
27 use crate::Event;
28 use crate::EventToken;
29 use crate::EventType;
30 use crate::RawDescriptor;
31 use crate::TriggeredEvent;
32 use crate::WaitContext;
33 
34 // MAXIMUM_WAIT_OBJECTS = 64
35 pub const MAXIMUM_WAIT_OBJECTS: usize = winapi::um::winnt::MAXIMUM_WAIT_OBJECTS as usize;
36 
37 // TODO(145170451) rizhang: implement round robin if event size is greater than 64
38 
39 pub trait WaitContextExt {
40     /// Removes all handles registered in the WaitContext.
clear(&self) -> Result<()>41     fn clear(&self) -> Result<()>;
42 }
43 
44 impl<T: EventToken> WaitContextExt for WaitContext<T> {
clear(&self) -> Result<()>45     fn clear(&self) -> Result<()> {
46         self.0.clear()
47     }
48 }
49 
50 struct RegisteredHandles<T: EventToken> {
51     triggers: HashMap<Descriptor, T>,
52     raw_handles: Vec<Descriptor>,
53 }
54 
55 pub struct EventContext<T: EventToken> {
56     registered_handles: Arc<Mutex<RegisteredHandles<T>>>,
57 
58     // An internally-used event to signify that the list of handles has been modified
59     // mid-wait. This is to solve for instances where Thread A has started waiting and
60     // Thread B adds an event trigger, which needs to notify Thread A a change has been
61     // made.
62     handles_modified_event: Event,
63 }
64 
65 impl<T: EventToken> EventContext<T> {
new() -> Result<EventContext<T>>66     pub fn new() -> Result<EventContext<T>> {
67         let new = EventContext {
68             registered_handles: Arc::new(Mutex::new(RegisteredHandles {
69                 triggers: HashMap::new(),
70                 raw_handles: Vec::new(),
71             })),
72             handles_modified_event: Event::new().unwrap(),
73         };
74         // The handles-modified event will be everpresent on the raw_handles to be waited
75         // upon to ensure the wait stops and we update it any time the handles list is
76         // modified.
77         new.registered_handles
78             .lock()
79             .raw_handles
80             .push(Descriptor(new.handles_modified_event.as_raw_descriptor()));
81         Ok(new)
82     }
83 
84     /// Creates a new EventContext with the the associated triggers.
build_with(triggers: &[EventTrigger<T>]) -> Result<EventContext<T>>85     pub fn build_with(triggers: &[EventTrigger<T>]) -> Result<EventContext<T>> {
86         let ctx = EventContext::new()?;
87         ctx.add_many(triggers)?;
88         Ok(ctx)
89     }
90 
91     /// Adds a trigger to the EventContext.
add(&self, trigger: EventTrigger<T>) -> Result<()>92     pub fn add(&self, trigger: EventTrigger<T>) -> Result<()> {
93         self.add_for_event_impl(trigger, EventType::Read)
94     }
95 
96     /// Adds a trigger to the EventContext.
add_many(&self, triggers: &[EventTrigger<T>]) -> Result<()>97     pub fn add_many(&self, triggers: &[EventTrigger<T>]) -> Result<()> {
98         for trigger in triggers {
99             self.add(trigger.clone())?
100         }
101         Ok(())
102     }
103 
add_for_event( &self, descriptor: &dyn AsRawDescriptor, event_type: EventType, token: T, ) -> Result<()>104     pub fn add_for_event(
105         &self,
106         descriptor: &dyn AsRawDescriptor,
107         event_type: EventType,
108         token: T,
109     ) -> Result<()> {
110         self.add_for_event_impl(EventTrigger::from(descriptor, token), event_type)
111     }
112 
add_for_event_impl(&self, trigger: EventTrigger<T>, _event_type: EventType) -> Result<()>113     fn add_for_event_impl(&self, trigger: EventTrigger<T>, _event_type: EventType) -> Result<()> {
114         let mut registered_handles_locked = self.registered_handles.lock();
115         if registered_handles_locked
116             .triggers
117             .contains_key(&Descriptor(trigger.event))
118         {
119             // If this handle is already added, silently succeed with a noop
120             return Ok(());
121         }
122         registered_handles_locked
123             .triggers
124             .insert(Descriptor(trigger.event), trigger.token);
125         registered_handles_locked
126             .raw_handles
127             .push(Descriptor(trigger.event));
128         // Windows doesn't support watching for specific types of events. Just treat this
129         // like a normal add and do nothing with event_type
130         self.handles_modified_event.signal()
131     }
132 
modify( &self, descriptor: &dyn AsRawDescriptor, _event_type: EventType, token: T, ) -> Result<()>133     pub fn modify(
134         &self,
135         descriptor: &dyn AsRawDescriptor,
136         _event_type: EventType,
137         token: T,
138     ) -> Result<()> {
139         let trigger = EventTrigger::from(descriptor, token);
140 
141         let mut registered_handles_locked = self.registered_handles.lock();
142         if let std::collections::hash_map::Entry::Occupied(mut e) = registered_handles_locked
143             .triggers
144             .entry(Descriptor(trigger.event))
145         {
146             e.insert(trigger.token);
147         }
148         // Windows doesn't support watching for specific types of events. Ignore the event_type
149         // and just modify the token.
150         self.handles_modified_event.signal()
151     }
152 
delete(&self, event_handle: &dyn AsRawDescriptor) -> Result<()>153     pub fn delete(&self, event_handle: &dyn AsRawDescriptor) -> Result<()> {
154         let mut registered_handles_locked = self.registered_handles.lock();
155         let result = registered_handles_locked
156             .triggers
157             .remove(&Descriptor(event_handle.as_raw_descriptor()));
158         if result.is_none() {
159             // this handle was not registered in the first place. Silently succeed with a noop
160             return Ok(());
161         }
162         let index = registered_handles_locked
163             .raw_handles
164             .iter()
165             .position(|item| item == &Descriptor(event_handle.as_raw_descriptor()))
166             .unwrap();
167         registered_handles_locked.raw_handles.remove(index);
168         self.handles_modified_event.signal()
169     }
170 
clear(&self) -> Result<()>171     pub fn clear(&self) -> Result<()> {
172         let mut registered_handles_locked = self.registered_handles.lock();
173         registered_handles_locked.triggers.clear();
174         registered_handles_locked.raw_handles.clear();
175 
176         registered_handles_locked
177             .raw_handles
178             .push(Descriptor(self.handles_modified_event.as_raw_descriptor()));
179         self.handles_modified_event.signal()
180     }
181 
182     /// Waits for one or more of the registered triggers to become signaled.
wait(&self) -> Result<SmallVec<[TriggeredEvent<T>; 16]>>183     pub fn wait(&self) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> {
184         self.wait_timeout(Duration::new(i64::MAX as u64, 0))
185     }
186 
wait_timeout(&self, timeout: Duration) -> Result<SmallVec<[TriggeredEvent<T>; 16]>>187     pub fn wait_timeout(&self, timeout: Duration) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> {
188         let raw_handles_list: Vec<RawHandle> = self
189             .registered_handles
190             .lock()
191             .raw_handles
192             .clone()
193             .into_iter()
194             .map(|handle| handle.0)
195             .collect();
196         if raw_handles_list.len() == 1 {
197             // Disallow calls with no handles to wait on. Do not include the handles_modified_event
198             // which always populates the list.
199             return Err(Error::new(ERROR_INVALID_PARAMETER));
200         }
201         // SAFETY: raw handles array is expected to contain valid handles and the return value of
202         // the function is checked.
203         let result = unsafe {
204             WaitForMultipleObjects(
205                 raw_handles_list.len() as DWORD,
206                 raw_handles_list.as_ptr(),
207                 FALSE, // return when one event is signaled
208                 timeout.as_millis() as DWORD,
209             )
210         };
211         let handles_len = min(MAXIMUM_WAIT_OBJECTS, raw_handles_list.len());
212 
213         const MAXIMUM_WAIT_OBJECTS_U32: u32 = MAXIMUM_WAIT_OBJECTS as u32;
214         match result {
215             WAIT_OBJECT_0..=MAXIMUM_WAIT_OBJECTS_U32 => {
216                 let mut event_index = (result - WAIT_OBJECT_0) as usize;
217                 if event_index >= handles_len {
218                     // This is not a valid index and should return an error. This case should not be
219                     // possible and will likely not return a meaningful system
220                     // error code, but is still an invalid case.
221                     error!("Wait returned index out of range");
222                     return errno_result();
223                 }
224                 if event_index == 0 {
225                     // The handles list has been modified and triggered the wait, try again with the
226                     // updated handles list. Note it is possible the list was
227                     // modified again after the wait which will trigger the
228                     // handles_modified_event again, but that will only err towards the safe side
229                     // of recursing an extra time.
230                     let _ = self.handles_modified_event.wait();
231                     return self.wait_timeout(timeout);
232                 }
233 
234                 let mut events_to_return = SmallVec::<[TriggeredEvent<T>; 16]>::new();
235                 // Multiple events may be triggered at once, but WaitForMultipleObjects will only
236                 // return one. Once it returns, loop through the remaining triggers
237                 // checking each to ensure they haven't also been triggered.
238                 let mut handles_offset: usize = 0;
239                 loop {
240                     let event_to_return = raw_handles_list[event_index + handles_offset];
241                     events_to_return.push(TriggeredEvent {
242                         token: T::from_raw_token(
243                             self.registered_handles
244                                 .lock()
245                                 .triggers
246                                 .get(&Descriptor(event_to_return))
247                                 .unwrap()
248                                 .as_raw_token(),
249                         ),
250                         // In Windows, events aren't associated with read/writability, so for cross-
251                         // compatability, associate with both.
252                         is_readable: true,
253                         is_writable: true,
254                         is_hungup: false,
255                     });
256 
257                     handles_offset += event_index + 1;
258                     if handles_offset >= handles_len {
259                         break;
260                     }
261                     event_index = (
262                         // SAFETY: raw handles array is expected to contain valid handles and the
263                         // return value of the function is checked.
264                         unsafe {
265                             WaitForMultipleObjects(
266                                 (raw_handles_list.len() - handles_offset) as DWORD,
267                                 raw_handles_list[handles_offset..].as_ptr(),
268                                 FALSE, // return when one event is signaled
269                                 0,     /* instantaneous timeout */
270                             )
271                         } - WAIT_OBJECT_0
272                     ) as usize;
273 
274                     if event_index >= (handles_len - handles_offset) {
275                         // This indicates a failure condition, as return values greater than the
276                         // length of the provided array are reserved for
277                         // failures.
278                         break;
279                     }
280                 }
281 
282                 Ok(events_to_return)
283             }
284             WAIT_TIMEOUT => Ok(Default::default()),
285             // Invalid cases. This is most likely an WAIT_FAILED, but anything not matched by the
286             // above is an error case.
287             _ => errno_result(),
288         }
289     }
290 }
291 
292 impl<T: EventToken> AsRawDescriptor for EventContext<T> {
as_raw_descriptor(&self) -> RawDescriptor293     fn as_raw_descriptor(&self) -> RawDescriptor {
294         self.handles_modified_event.as_raw_descriptor()
295     }
296 }
297 
#[cfg(test)]
mod tests {
    use super::*;

    /// Waiting on a context with no registered triggers is rejected.
    #[test]
    #[should_panic]
    fn error_on_empty_context_wait() {
        let ctx: EventContext<u32> = EventContext::new().unwrap();
        let dur = Duration::from_millis(10);
        ctx.wait_timeout(dur).unwrap();
    }

    /// A signaled trigger is reported with the token it was registered under.
    #[test]
    fn wait_returns_token_of_signaled_event() {
        let evt = Event::new().unwrap();
        let ctx: EventContext<u32> = EventContext::new().unwrap();
        ctx.add_for_event(&evt, EventType::Read, 7).unwrap();
        evt.signal().unwrap();

        let events = ctx.wait_timeout(Duration::from_millis(10)).unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].token, 7);
    }

    /// With nothing signaled, the wait times out and reports no events.
    #[test]
    fn wait_times_out_with_no_signaled_events() {
        let evt = Event::new().unwrap();
        let ctx: EventContext<u32> = EventContext::new().unwrap();
        ctx.add_for_event(&evt, EventType::Read, 7).unwrap();

        let events = ctx.wait_timeout(Duration::from_millis(1)).unwrap();
        assert!(events.is_empty());
    }
}
310