xref: /aosp_15_r20/external/crosvm/base/src/sys/unix/stream_channel.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::io::Read;
7 use std::os::unix::io::AsRawFd;
8 use std::os::unix::io::RawFd;
9 use std::os::unix::net::UnixStream;
10 use std::time::Duration;
11 
12 use libc::c_void;
13 use serde::Deserialize;
14 use serde::Serialize;
15 
16 use super::super::net::UnixSeqpacket;
17 use crate::descriptor::AsRawDescriptor;
18 use crate::IntoRawDescriptor;
19 use crate::RawDescriptor;
20 use crate::ReadNotifier;
21 use crate::Result;
22 
/// How data written into a [`StreamChannel`] is delimited when it is read back out.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Packet-oriented framing: each read returns at most one packet. On this platform the
    /// channel is backed by a sequenced-packet (`UnixSeqpacket`) socket pair.
    Message,
    /// Stream-oriented framing: data is a continuous byte stream with no message boundaries.
    /// On this platform the channel is backed by a `UnixStream` socket pair.
    Byte,
}
28 
/// Whether the sockets created by [`StreamChannel::pair`] are put into non-blocking mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// Reads and writes wait until they can make progress.
    Blocking,
    /// The underlying sockets are switched to non-blocking mode, so operations that cannot
    /// proceed immediately fail (typically with `io::ErrorKind::WouldBlock`) instead of waiting.
    Nonblocking,
}
34 
35 impl io::Read for StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>36     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37         self.inner_read(buf)
38     }
39 }
40 
41 impl io::Read for &StreamChannel {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>42     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
43         self.inner_read(buf)
44     }
45 }
46 
47 impl AsRawDescriptor for StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor48     fn as_raw_descriptor(&self) -> RawDescriptor {
49         (&self).as_raw_descriptor()
50     }
51 }
52 
53 #[derive(Debug, Deserialize, Serialize)]
54 enum SocketType {
55     Message(UnixSeqpacket),
56     #[serde(with = "crate::with_as_descriptor")]
57     Byte(UnixStream),
58 }
59 
60 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking
61 /// and non blocking mode.
62 #[derive(Debug, Deserialize, Serialize)]
63 pub struct StreamChannel {
64     stream: SocketType,
65 }
66 
67 impl StreamChannel {
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>68     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
69         match &mut self.stream {
70             SocketType::Byte(sock) => sock.set_nonblocking(nonblocking),
71             SocketType::Message(sock) => sock.set_nonblocking(nonblocking),
72         }
73     }
74 
get_framing_mode(&self) -> FramingMode75     pub fn get_framing_mode(&self) -> FramingMode {
76         match &self.stream {
77             SocketType::Message(_) => FramingMode::Message,
78             SocketType::Byte(_) => FramingMode::Byte,
79         }
80     }
81 
inner_read(&self, buf: &mut [u8]) -> io::Result<usize>82     pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
83         match &self.stream {
84             SocketType::Byte(sock) => (&mut &*sock).read(buf),
85 
86             // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error,
87             // but on Linux will silently truncate unless MSG_TRUNC is passed. Here, we emulate
88             // Windows behavior on POSIX.
89             //
90             // Note that Rust translates ERROR_MORE_DATA into io::ErrorKind::Other
91             // (see sys::decode_error_kind) on Windows, so we preserve this behavior on POSIX even
92             // though one could argue ErrorKind::UnexpectedEof is a closer match to the true error.
93             SocketType::Message(sock) => {
94                 // SAFETY:
95                 // Safe because buf is valid, we pass buf's size to recv to bound the return
96                 // length, and we check the return code.
97                 let retval = unsafe {
98                     // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a
99                     // recv_with_flags method once that struct's tests are working.
100                     libc::recv(
101                         sock.as_raw_descriptor(),
102                         buf.as_mut_ptr() as *mut c_void,
103                         buf.len(),
104                         libc::MSG_TRUNC,
105                     )
106                 };
107                 let receive_len = if retval < 0 {
108                     Err(std::io::Error::last_os_error())
109                 } else {
110                     Ok(retval)
111                 }? as usize;
112 
113                 if receive_len > buf.len() {
114                     Err(std::io::Error::new(
115                         std::io::ErrorKind::Other,
116                         format!(
117                             "packet size {:?} encountered, but buffer was only of size {:?}",
118                             receive_len,
119                             buf.len()
120                         ),
121                     ))
122                 } else {
123                     Ok(receive_len)
124                 }
125             }
126         }
127     }
128 
129     /// Creates a cross platform stream pair.
pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>130     pub fn pair(
131         blocking_mode: BlockingMode,
132         framing_mode: FramingMode,
133     ) -> Result<(StreamChannel, StreamChannel)> {
134         let (pipe_a, pipe_b) = match framing_mode {
135             FramingMode::Byte => {
136                 let (pipe_a, pipe_b) = UnixStream::pair()?;
137                 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b))
138             }
139             FramingMode::Message => {
140                 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?;
141                 (SocketType::Message(pipe_a), SocketType::Message(pipe_b))
142             }
143         };
144         let mut stream_a = StreamChannel { stream: pipe_a };
145         let mut stream_b = StreamChannel { stream: pipe_b };
146         let is_non_blocking = blocking_mode == BlockingMode::Nonblocking;
147         stream_a.set_nonblocking(is_non_blocking)?;
148         stream_b.set_nonblocking(is_non_blocking)?;
149         Ok((stream_a, stream_b))
150     }
151 
from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel152     pub fn from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel {
153         StreamChannel {
154             stream: SocketType::Message(sock),
155         }
156     }
157 
peek_size(&self) -> io::Result<usize>158     pub fn peek_size(&self) -> io::Result<usize> {
159         match &self.stream {
160             SocketType::Byte(_) => Err(std::io::Error::new(
161                 std::io::ErrorKind::Other,
162                 "Cannot check the size of streamed data",
163             )),
164             SocketType::Message(sock) => Ok(sock.next_packet_size()?),
165         }
166     }
167 
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>168     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
169         match &self.stream {
170             SocketType::Byte(sock) => sock.set_read_timeout(timeout),
171             SocketType::Message(sock) => sock.set_read_timeout(timeout),
172         }
173     }
174 
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>175     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
176         match &self.stream {
177             SocketType::Byte(sock) => sock.set_write_timeout(timeout),
178             SocketType::Message(sock) => sock.set_write_timeout(timeout),
179         }
180     }
181 
182     // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with
183     // > 1 reader per end is not defined.
try_clone(&self) -> io::Result<Self>184     pub fn try_clone(&self) -> io::Result<Self> {
185         Ok(StreamChannel {
186             stream: match &self.stream {
187                 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?),
188                 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?),
189             },
190         })
191     }
192 }
193 
194 impl io::Write for StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>195     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
196         match &mut self.stream {
197             SocketType::Byte(sock) => sock.write(buf),
198             SocketType::Message(sock) => sock.send(buf),
199         }
200     }
flush(&mut self) -> io::Result<()>201     fn flush(&mut self) -> io::Result<()> {
202         match &mut self.stream {
203             SocketType::Byte(sock) => sock.flush(),
204             SocketType::Message(_) => Ok(()),
205         }
206     }
207 }
208 
209 impl io::Write for &StreamChannel {
write(&mut self, buf: &[u8]) -> io::Result<usize>210     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
211         match &self.stream {
212             SocketType::Byte(sock) => (&mut &*sock).write(buf),
213             SocketType::Message(sock) => sock.send(buf),
214         }
215     }
flush(&mut self) -> io::Result<()>216     fn flush(&mut self) -> io::Result<()> {
217         match &self.stream {
218             SocketType::Byte(sock) => (&mut &*sock).flush(),
219             SocketType::Message(_) => Ok(()),
220         }
221     }
222 }
223 
224 impl AsRawFd for StreamChannel {
as_raw_fd(&self) -> RawFd225     fn as_raw_fd(&self) -> RawFd {
226         match &self.stream {
227             SocketType::Byte(sock) => sock.as_raw_descriptor(),
228             SocketType::Message(sock) => sock.as_raw_descriptor(),
229         }
230     }
231 }
232 
233 impl AsRawFd for &StreamChannel {
as_raw_fd(&self) -> RawFd234     fn as_raw_fd(&self) -> RawFd {
235         self.as_raw_descriptor()
236     }
237 }
238 
239 impl AsRawDescriptor for &StreamChannel {
as_raw_descriptor(&self) -> RawDescriptor240     fn as_raw_descriptor(&self) -> RawDescriptor {
241         match &self.stream {
242             SocketType::Byte(sock) => sock.as_raw_descriptor(),
243             SocketType::Message(sock) => sock.as_raw_descriptor(),
244         }
245     }
246 }
247 
248 impl IntoRawDescriptor for StreamChannel {
into_raw_descriptor(self) -> RawFd249     fn into_raw_descriptor(self) -> RawFd {
250         match self.stream {
251             SocketType::Byte(sock) => sock.into_raw_descriptor(),
252             SocketType::Message(sock) => sock.into_raw_descriptor(),
253         }
254     }
255 }
256 
257 impl ReadNotifier for StreamChannel {
258     /// Returns a RawDescriptor that can be polled for reads using PollContext.
get_read_notifier(&self) -> &dyn AsRawDescriptor259     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
260         self
261     }
262 }
263 
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;

    use super::*;
    use crate::EventContext;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Registers `receiver` for readability, blocks until an event arrives, asserts that the
    /// only readable token is `Token::ReceivedData`, and returns the context for later checks.
    fn wait_until_readable(receiver: &StreamChannel) -> EventContext<Token> {
        let event_ctx: EventContext<Token> =
            EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)])
                .unwrap();
        let tokens: Vec<Token> = event_ctx
            .wait()
            .unwrap()
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);
        event_ctx
    }

    /// Asserts that a zero-timeout poll of `event_ctx` reports no events.
    fn assert_no_pending_events(event_ctx: &EventContext<Token>) {
        assert_eq!(
            event_ctx
                .wait_timeout(std::time::Duration::new(0, 0))
                .unwrap()
                .len(),
            0
        );
    }

    #[test]
    fn test_non_blocking_pair_byte() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx = wait_until_readable(&receiver);

        // Smaller than what we sent so we get multiple chunks
        let mut recv_buffer: [u8; 4] = [0; 4];

        let mut size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_non_blocking_pair_message() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx = wait_until_readable(&receiver);

        // Unlike Byte format, Message mode returns an error if the buffer is smaller than the
        // packet size; make the buffer the right size.
        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_no_pending_events(&event_ctx);
    }

    #[test]
    fn test_message_buffer_too_small() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Blocking, FramingMode::Message).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        // A buffer smaller than the packet must produce an error, not a truncated read.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let err = receiver
            .read(&mut recv_buffer)
            .expect_err("reading a packet into a too-small buffer should fail");
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        let _event_ctx = wait_until_readable(&receiver);

        // We only read 2 bytes, even though we requested 4 bytes.
        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // Further reads should report WouldBlock since there is no available data and this is
        // a non blocking pipe.
        let err = receiver
            .read(&mut recv_buffer)
            .expect_err("read on an empty non-blocking channel should fail");
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }

    #[test]
    fn test_peek_size_byte_mode_is_unsupported() {
        let (channel, _peer) =
            StreamChannel::pair(BlockingMode::Blocking, FramingMode::Byte).unwrap();
        assert!(channel.peek_size().is_err());
    }

    #[test]
    fn test_from_unix_seqpacket() {
        let (sock_sender, sock_receiver) = UnixSeqpacket::pair().unwrap();
        let mut sender = StreamChannel::from_unix_seqpacket(sock_sender);
        let mut receiver = StreamChannel::from_unix_seqpacket(sock_receiver);

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx = wait_until_readable(&receiver);

        let mut recv_buffer: [u8; 6] = [0; 6];

        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 6);
        assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]);

        // Now that we've polled for & received all data, polling again should show no events.
        assert_no_pending_events(&event_ctx);
    }
}
424 }
425