1 //! Split a single value implementing `AsyncRead + AsyncWrite` into separate 2 //! `AsyncRead` and `AsyncWrite` handles. 3 //! 4 //! To restore this read/write object from its `split::ReadHalf` and 5 //! `split::WriteHalf` use `unsplit`. 6 7 use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 9 use std::fmt; 10 use std::io; 11 use std::pin::Pin; 12 use std::sync::Arc; 13 use std::sync::Mutex; 14 use std::task::{Context, Poll}; 15 16 cfg_io_util! { 17 /// The readable half of a value returned from [`split`](split()). 18 pub struct ReadHalf<T> { 19 inner: Arc<Inner<T>>, 20 } 21 22 /// The writable half of a value returned from [`split`](split()). 23 pub struct WriteHalf<T> { 24 inner: Arc<Inner<T>>, 25 } 26 27 /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate 28 /// `AsyncRead` and `AsyncWrite` handles. 29 /// 30 /// To restore this read/write object from its `ReadHalf` and 31 /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). 32 pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) 33 where 34 T: AsyncRead + AsyncWrite, 35 { 36 let is_write_vectored = stream.is_write_vectored(); 37 38 let inner = Arc::new(Inner { 39 stream: Mutex::new(stream), 40 is_write_vectored, 41 }); 42 43 let rd = ReadHalf { 44 inner: inner.clone(), 45 }; 46 47 let wr = WriteHalf { inner }; 48 49 (rd, wr) 50 } 51 } 52 53 struct Inner<T> { 54 stream: Mutex<T>, 55 is_write_vectored: bool, 56 } 57 58 impl<T> Inner<T> { with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R59 fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R { 60 let mut guard = self.stream.lock().unwrap(); 61 62 // safety: we do not move the stream. 63 let stream = unsafe { Pin::new_unchecked(&mut *guard) }; 64 65 f(stream) 66 } 67 } 68 69 impl<T> ReadHalf<T> { 70 /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same 71 /// stream. is_pair_of(&self, other: &WriteHalf<T>) -> bool72 pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { 73 other.is_pair_of(self) 74 } 75 76 /// Reunites with a previously split `WriteHalf`. 77 /// 78 /// # Panics 79 /// 80 /// If this `ReadHalf` and the given `WriteHalf` do not originate from the 81 /// same `split` operation this method will panic. 82 /// This can be checked ahead of time by calling [`is_pair_of()`](Self::is_pair_of). 83 #[track_caller] unsplit(self, wr: WriteHalf<T>) -> T where T: Unpin,84 pub fn unsplit(self, wr: WriteHalf<T>) -> T 85 where 86 T: Unpin, 87 { 88 if self.is_pair_of(&wr) { 89 drop(wr); 90 91 let inner = Arc::try_unwrap(self.inner) 92 .ok() 93 .expect("`Arc::try_unwrap` failed"); 94 95 inner.stream.into_inner().unwrap() 96 } else { 97 panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") 98 } 99 } 100 } 101 102 impl<T> WriteHalf<T> { 103 /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same 104 /// stream. is_pair_of(&self, other: &ReadHalf<T>) -> bool105 pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { 106 Arc::ptr_eq(&self.inner, &other.inner) 107 } 108 } 109 110 impl<T: AsyncRead> AsyncRead for ReadHalf<T> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>111 fn poll_read( 112 self: Pin<&mut Self>, 113 cx: &mut Context<'_>, 114 buf: &mut ReadBuf<'_>, 115 ) -> Poll<io::Result<()>> { 116 self.inner.with_lock(|stream| stream.poll_read(cx, buf)) 117 } 118 } 119 120 impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>121 fn poll_write( 122 self: Pin<&mut Self>, 123 cx: &mut Context<'_>, 124 buf: &[u8], 125 ) -> Poll<Result<usize, io::Error>> { 126 self.inner.with_lock(|stream| stream.poll_write(cx, buf)) 127 } 128 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>129 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 130 self.inner.with_lock(|stream| stream.poll_flush(cx)) 131 } 132 poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>133 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 134 self.inner.with_lock(|stream| stream.poll_shutdown(cx)) 135 } 136 poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>137 fn poll_write_vectored( 138 self: Pin<&mut Self>, 139 cx: &mut Context<'_>, 140 bufs: &[io::IoSlice<'_>], 141 ) -> Poll<Result<usize, io::Error>> { 142 self.inner 143 .with_lock(|stream| stream.poll_write_vectored(cx, bufs)) 144 } 145 is_write_vectored(&self) -> bool146 fn is_write_vectored(&self) -> bool { 147 self.inner.is_write_vectored 148 } 149 } 150 151 unsafe impl<T: Send> Send for ReadHalf<T> {} 152 unsafe impl<T: Send> Send for WriteHalf<T> {} 153 unsafe impl<T: Sync> Sync for ReadHalf<T> {} 154 unsafe impl<T: Sync> Sync for WriteHalf<T> {} 155 156 impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result157 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 158 fmt.debug_struct("split::ReadHalf").finish() 159 } 160 } 161 162 impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result163 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 164 fmt.debug_struct("split::WriteHalf").finish() 165 } 166 } 167