1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::io; 6 use std::io::Read; 7 use std::os::unix::io::AsRawFd; 8 use std::os::unix::io::RawFd; 9 use std::os::unix::net::UnixStream; 10 use std::time::Duration; 11 12 use libc::c_void; 13 use serde::Deserialize; 14 use serde::Serialize; 15 16 use super::super::net::UnixSeqpacket; 17 use crate::descriptor::AsRawDescriptor; 18 use crate::IntoRawDescriptor; 19 use crate::RawDescriptor; 20 use crate::ReadNotifier; 21 use crate::Result; 22 23 #[derive(Copy, Clone)] 24 pub enum FramingMode { 25 Message, 26 Byte, 27 } 28 29 #[derive(Copy, Clone, PartialEq, Eq)] 30 pub enum BlockingMode { 31 Blocking, 32 Nonblocking, 33 } 34 35 impl io::Read for StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>36 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 37 self.inner_read(buf) 38 } 39 } 40 41 impl io::Read for &StreamChannel { read(&mut self, buf: &mut [u8]) -> io::Result<usize>42 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { 43 self.inner_read(buf) 44 } 45 } 46 47 impl AsRawDescriptor for StreamChannel { as_raw_descriptor(&self) -> RawDescriptor48 fn as_raw_descriptor(&self) -> RawDescriptor { 49 (&self).as_raw_descriptor() 50 } 51 } 52 53 #[derive(Debug, Deserialize, Serialize)] 54 enum SocketType { 55 Message(UnixSeqpacket), 56 #[serde(with = "crate::with_as_descriptor")] 57 Byte(UnixStream), 58 } 59 60 /// An abstraction over named pipes and unix socketpairs. This abstraction can be used in a blocking 61 /// and non blocking mode. 62 #[derive(Debug, Deserialize, Serialize)] 63 pub struct StreamChannel { 64 stream: SocketType, 65 } 66 67 impl StreamChannel { set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>68 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { 69 match &mut self.stream { 70 SocketType::Byte(sock) => sock.set_nonblocking(nonblocking), 71 SocketType::Message(sock) => sock.set_nonblocking(nonblocking), 72 } 73 } 74 get_framing_mode(&self) -> FramingMode75 pub fn get_framing_mode(&self) -> FramingMode { 76 match &self.stream { 77 SocketType::Message(_) => FramingMode::Message, 78 SocketType::Byte(_) => FramingMode::Byte, 79 } 80 } 81 inner_read(&self, buf: &mut [u8]) -> io::Result<usize>82 pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> { 83 match &self.stream { 84 SocketType::Byte(sock) => (&mut &*sock).read(buf), 85 86 // On Windows, reading from SOCK_SEQPACKET with a buffer that is too small is an error, 87 // but on Linux will silently truncate unless MSG_TRUNC is passed. Here, we emulate 88 // Windows behavior on POSIX. 89 // 90 // Note that Rust translates ERROR_MORE_DATA into io::ErrorKind::Other 91 // (see sys::decode_error_kind) on Windows, so we preserve this behavior on POSIX even 92 // though one could argue ErrorKind::UnexpectedEof is a closer match to the true error. 93 SocketType::Message(sock) => { 94 // SAFETY: 95 // Safe because buf is valid, we pass buf's size to recv to bound the return 96 // length, and we check the return code. 97 let retval = unsafe { 98 // TODO(nkgold|b/152067913): Move this into the UnixSeqpacket struct as a 99 // recv_with_flags method once that struct's tests are working. 100 libc::recv( 101 sock.as_raw_descriptor(), 102 buf.as_mut_ptr() as *mut c_void, 103 buf.len(), 104 libc::MSG_TRUNC, 105 ) 106 }; 107 let receive_len = if retval < 0 { 108 Err(std::io::Error::last_os_error()) 109 } else { 110 Ok(retval) 111 }? as usize; 112 113 if receive_len > buf.len() { 114 Err(std::io::Error::new( 115 std::io::ErrorKind::Other, 116 format!( 117 "packet size {:?} encountered, but buffer was only of size {:?}", 118 receive_len, 119 buf.len() 120 ), 121 )) 122 } else { 123 Ok(receive_len) 124 } 125 } 126 } 127 } 128 129 /// Creates a cross platform stream pair. pair( blocking_mode: BlockingMode, framing_mode: FramingMode, ) -> Result<(StreamChannel, StreamChannel)>130 pub fn pair( 131 blocking_mode: BlockingMode, 132 framing_mode: FramingMode, 133 ) -> Result<(StreamChannel, StreamChannel)> { 134 let (pipe_a, pipe_b) = match framing_mode { 135 FramingMode::Byte => { 136 let (pipe_a, pipe_b) = UnixStream::pair()?; 137 (SocketType::Byte(pipe_a), SocketType::Byte(pipe_b)) 138 } 139 FramingMode::Message => { 140 let (pipe_a, pipe_b) = UnixSeqpacket::pair()?; 141 (SocketType::Message(pipe_a), SocketType::Message(pipe_b)) 142 } 143 }; 144 let mut stream_a = StreamChannel { stream: pipe_a }; 145 let mut stream_b = StreamChannel { stream: pipe_b }; 146 let is_non_blocking = blocking_mode == BlockingMode::Nonblocking; 147 stream_a.set_nonblocking(is_non_blocking)?; 148 stream_b.set_nonblocking(is_non_blocking)?; 149 Ok((stream_a, stream_b)) 150 } 151 from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel152 pub fn from_unix_seqpacket(sock: UnixSeqpacket) -> StreamChannel { 153 StreamChannel { 154 stream: SocketType::Message(sock), 155 } 156 } 157 peek_size(&self) -> io::Result<usize>158 pub fn peek_size(&self) -> io::Result<usize> { 159 match &self.stream { 160 SocketType::Byte(_) => Err(std::io::Error::new( 161 std::io::ErrorKind::Other, 162 "Cannot check the size of streamed data", 163 )), 164 SocketType::Message(sock) => Ok(sock.next_packet_size()?), 165 } 166 } 167 set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>168 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 169 match &self.stream { 170 SocketType::Byte(sock) => sock.set_read_timeout(timeout), 171 SocketType::Message(sock) => sock.set_read_timeout(timeout), 172 } 173 } 174 set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>175 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { 176 match &self.stream { 177 SocketType::Byte(sock) => sock.set_write_timeout(timeout), 178 SocketType::Message(sock) => sock.set_write_timeout(timeout), 179 } 180 } 181 182 // WARNING: Generally, multiple StreamChannel ends are not wanted. StreamChannel behavior with 183 // > 1 reader per end is not defined. try_clone(&self) -> io::Result<Self>184 pub fn try_clone(&self) -> io::Result<Self> { 185 Ok(StreamChannel { 186 stream: match &self.stream { 187 SocketType::Byte(sock) => SocketType::Byte(sock.try_clone()?), 188 SocketType::Message(sock) => SocketType::Message(sock.try_clone()?), 189 }, 190 }) 191 } 192 } 193 194 impl io::Write for StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>195 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 196 match &mut self.stream { 197 SocketType::Byte(sock) => sock.write(buf), 198 SocketType::Message(sock) => sock.send(buf), 199 } 200 } flush(&mut self) -> io::Result<()>201 fn flush(&mut self) -> io::Result<()> { 202 match &mut self.stream { 203 SocketType::Byte(sock) => sock.flush(), 204 SocketType::Message(_) => Ok(()), 205 } 206 } 207 } 208 209 impl io::Write for &StreamChannel { write(&mut self, buf: &[u8]) -> io::Result<usize>210 fn write(&mut self, buf: &[u8]) -> io::Result<usize> { 211 match &self.stream { 212 SocketType::Byte(sock) => (&mut &*sock).write(buf), 213 SocketType::Message(sock) => sock.send(buf), 214 } 215 } flush(&mut self) -> io::Result<()>216 fn flush(&mut self) -> io::Result<()> { 217 match &self.stream { 218 SocketType::Byte(sock) => (&mut &*sock).flush(), 219 SocketType::Message(_) => Ok(()), 220 } 221 } 222 } 223 224 impl AsRawFd for StreamChannel { as_raw_fd(&self) -> RawFd225 fn as_raw_fd(&self) -> RawFd { 226 match &self.stream { 227 SocketType::Byte(sock) => sock.as_raw_descriptor(), 228 SocketType::Message(sock) => sock.as_raw_descriptor(), 229 } 230 } 231 } 232 233 impl AsRawFd for &StreamChannel { as_raw_fd(&self) -> RawFd234 fn as_raw_fd(&self) -> RawFd { 235 self.as_raw_descriptor() 236 } 237 } 238 239 impl AsRawDescriptor for &StreamChannel { as_raw_descriptor(&self) -> RawDescriptor240 fn as_raw_descriptor(&self) -> RawDescriptor { 241 match &self.stream { 242 SocketType::Byte(sock) => sock.as_raw_descriptor(), 243 SocketType::Message(sock) => sock.as_raw_descriptor(), 244 } 245 } 246 } 247 248 impl IntoRawDescriptor for StreamChannel { into_raw_descriptor(self) -> RawFd249 fn into_raw_descriptor(self) -> RawFd { 250 match self.stream { 251 SocketType::Byte(sock) => sock.into_raw_descriptor(), 252 SocketType::Message(sock) => sock.into_raw_descriptor(), 253 } 254 } 255 } 256 257 impl ReadNotifier for StreamChannel { 258 /// Returns a RawDescriptor that can be polled for reads using PollContext. get_read_notifier(&self) -> &dyn AsRawDescriptor259 fn get_read_notifier(&self) -> &dyn AsRawDescriptor { 260 self 261 } 262 } 263 264 #[cfg(test)] 265 mod test { 266 use std::io::Read; 267 use std::io::Write; 268 269 use super::*; 270 use crate::EventContext; 271 use crate::EventToken; 272 use crate::ReadNotifier; 273 274 #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)] 275 enum Token { 276 ReceivedData, 277 } 278 279 #[test] test_non_blocking_pair_byte()280 fn test_non_blocking_pair_byte() { 281 let (mut sender, mut receiver) = 282 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 283 284 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 285 286 // Wait for the data to arrive. 287 let event_ctx: EventContext<Token> = 288 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 289 .unwrap(); 290 let events = event_ctx.wait().unwrap(); 291 let tokens: Vec<Token> = events 292 .iter() 293 .filter(|e| e.is_readable) 294 .map(|e| e.token) 295 .collect(); 296 assert_eq!(tokens, vec! {Token::ReceivedData}); 297 298 // Smaller than what we sent so we get multiple chunks 299 let mut recv_buffer: [u8; 4] = [0; 4]; 300 301 let mut size = receiver.read(&mut recv_buffer).unwrap(); 302 assert_eq!(size, 4); 303 assert_eq!(recv_buffer, [75, 77, 54, 82]); 304 305 size = receiver.read(&mut recv_buffer).unwrap(); 306 assert_eq!(size, 2); 307 assert_eq!(recv_buffer[0..2], [76, 65]); 308 309 // Now that we've polled for & received all data, polling again should show no events. 310 assert_eq!( 311 event_ctx 312 .wait_timeout(std::time::Duration::new(0, 0)) 313 .unwrap() 314 .len(), 315 0 316 ); 317 } 318 319 #[test] test_non_blocking_pair_message()320 fn test_non_blocking_pair_message() { 321 let (mut sender, mut receiver) = 322 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Message).unwrap(); 323 324 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 325 326 // Wait for the data to arrive. 327 let event_ctx: EventContext<Token> = 328 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 329 .unwrap(); 330 let events = event_ctx.wait().unwrap(); 331 let tokens: Vec<Token> = events 332 .iter() 333 .filter(|e| e.is_readable) 334 .map(|e| e.token) 335 .collect(); 336 assert_eq!(tokens, vec! {Token::ReceivedData}); 337 338 // Unlike Byte format, Message mode panics if the buffer is smaller than the packet size; 339 // make the buffer the right size. 340 let mut recv_buffer: [u8; 6] = [0; 6]; 341 342 let size = receiver.read(&mut recv_buffer).unwrap(); 343 assert_eq!(size, 6); 344 assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]); 345 346 // Now that we've polled for & received all data, polling again should show no events. 347 assert_eq!( 348 event_ctx 349 .wait_timeout(std::time::Duration::new(0, 0)) 350 .unwrap() 351 .len(), 352 0 353 ); 354 } 355 356 #[test] test_non_blocking_pair_error_no_data()357 fn test_non_blocking_pair_error_no_data() { 358 let (mut sender, mut receiver) = 359 StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap(); 360 receiver 361 .set_nonblocking(true) 362 .expect("Failed to set receiver to nonblocking mode."); 363 364 sender.write_all(&[75, 77]).unwrap(); 365 366 // Wait for the data to arrive. 367 let event_ctx: EventContext<Token> = 368 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 369 .unwrap(); 370 let events = event_ctx.wait().unwrap(); 371 let tokens: Vec<Token> = events 372 .iter() 373 .filter(|e| e.is_readable) 374 .map(|e| e.token) 375 .collect(); 376 assert_eq!(tokens, vec! {Token::ReceivedData}); 377 378 // We only read 2 bytes, even though we requested 4 bytes. 379 let mut recv_buffer: [u8; 4] = [0; 4]; 380 let size = receiver.read(&mut recv_buffer).unwrap(); 381 assert_eq!(size, 2); 382 assert_eq!(recv_buffer, [75, 77, 00, 00]); 383 384 // Further reads should encounter an error since there is no available data and this is a 385 // non blocking pipe. 386 assert!(receiver.read(&mut recv_buffer).is_err()); 387 } 388 389 #[test] test_from_unix_seqpacket()390 fn test_from_unix_seqpacket() { 391 let (sock_sender, sock_receiver) = UnixSeqpacket::pair().unwrap(); 392 let mut sender = StreamChannel::from_unix_seqpacket(sock_sender); 393 let mut receiver = StreamChannel::from_unix_seqpacket(sock_receiver); 394 395 sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap(); 396 397 // Wait for the data to arrive. 398 let event_ctx: EventContext<Token> = 399 EventContext::build_with(&[(receiver.get_read_notifier(), Token::ReceivedData)]) 400 .unwrap(); 401 let events = event_ctx.wait().unwrap(); 402 let tokens: Vec<Token> = events 403 .iter() 404 .filter(|e| e.is_readable) 405 .map(|e| e.token) 406 .collect(); 407 assert_eq!(tokens, vec! {Token::ReceivedData}); 408 409 let mut recv_buffer: [u8; 6] = [0; 6]; 410 411 let size = receiver.read(&mut recv_buffer).unwrap(); 412 assert_eq!(size, 6); 413 assert_eq!(recv_buffer, [75, 77, 54, 82, 76, 65]); 414 415 // Now that we've polled for & received all data, polling again should show no events. 416 assert_eq!( 417 event_ctx 418 .wait_timeout(std::time::Duration::new(0, 0)) 419 .unwrap() 420 .len(), 421 0 422 ); 423 } 424 } 425