1 //! Split a single value implementing `AsyncRead + AsyncWrite` into separate
2 //! `AsyncRead` and `AsyncWrite` handles.
3 //!
//! To restore this read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use `ReadHalf::unsplit`.
6 
7 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
8 
9 use std::fmt;
10 use std::io;
11 use std::pin::Pin;
12 use std::sync::Arc;
13 use std::sync::Mutex;
14 use std::task::{Context, Poll};
15 
16 cfg_io_util! {
17     /// The readable half of a value returned from [`split`](split()).
18     pub struct ReadHalf<T> {
19         inner: Arc<Inner<T>>,
20     }
21 
22     /// The writable half of a value returned from [`split`](split()).
23     pub struct WriteHalf<T> {
24         inner: Arc<Inner<T>>,
25     }
26 
27     /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
28     /// `AsyncRead` and `AsyncWrite` handles.
29     ///
30     /// To restore this read/write object from its `ReadHalf` and
31     /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()).
32     pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
33     where
34         T: AsyncRead + AsyncWrite,
35     {
36         let is_write_vectored = stream.is_write_vectored();
37 
38         let inner = Arc::new(Inner {
39             stream: Mutex::new(stream),
40             is_write_vectored,
41         });
42 
43         let rd = ReadHalf {
44             inner: inner.clone(),
45         };
46 
47         let wr = WriteHalf { inner };
48 
49         (rd, wr)
50     }
51 }
52 
/// State shared, through one `Arc` allocation, by a `ReadHalf`/`WriteHalf` pair.
struct Inner<T> {
    /// The wrapped stream; while split, it is only reached via `with_lock`.
    stream: Mutex<T>,
    /// Result of `stream.is_write_vectored()` sampled when the value was split.
    is_write_vectored: bool,
}

impl<T> Inner<T> {
    /// Locks the stream, hands `f` a pinned mutable reference to it, and
    /// releases the lock once `f` has returned.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, i.e. an earlier closure panicked while
    /// the lock was held.
    fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
        let mut locked = self.stream.lock().unwrap();
        let raw: &mut T = &mut locked;

        // SAFETY: `T` lives inside the `Mutex` inside a heap-allocated `Arc`
        // and is never moved out of it while pinned: the only way to take it
        // back is `ReadHalf::unsplit`, which requires `T: Unpin`. Otherwise the
        // value is dropped in place together with the allocation.
        f(unsafe { Pin::new_unchecked(raw) })
    }
}
68 
69 impl<T> ReadHalf<T> {
70     /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same
71     /// stream.
is_pair_of(&self, other: &WriteHalf<T>) -> bool72     pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
73         other.is_pair_of(self)
74     }
75 
76     /// Reunites with a previously split `WriteHalf`.
77     ///
78     /// # Panics
79     ///
80     /// If this `ReadHalf` and the given `WriteHalf` do not originate from the
81     /// same `split` operation this method will panic.
82     /// This can be checked ahead of time by calling [`is_pair_of()`](Self::is_pair_of).
83     #[track_caller]
unsplit(self, wr: WriteHalf<T>) -> T where T: Unpin,84     pub fn unsplit(self, wr: WriteHalf<T>) -> T
85     where
86         T: Unpin,
87     {
88         if self.is_pair_of(&wr) {
89             drop(wr);
90 
91             let inner = Arc::try_unwrap(self.inner)
92                 .ok()
93                 .expect("`Arc::try_unwrap` failed");
94 
95             inner.stream.into_inner().unwrap()
96         } else {
97             panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
98         }
99     }
100 }
101 
102 impl<T> WriteHalf<T> {
103     /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same
104     /// stream.
is_pair_of(&self, other: &ReadHalf<T>) -> bool105     pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
106         Arc::ptr_eq(&self.inner, &other.inner)
107     }
108 }
109 
110 impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>111     fn poll_read(
112         self: Pin<&mut Self>,
113         cx: &mut Context<'_>,
114         buf: &mut ReadBuf<'_>,
115     ) -> Poll<io::Result<()>> {
116         self.inner.with_lock(|stream| stream.poll_read(cx, buf))
117     }
118 }
119 
120 impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>121     fn poll_write(
122         self: Pin<&mut Self>,
123         cx: &mut Context<'_>,
124         buf: &[u8],
125     ) -> Poll<Result<usize, io::Error>> {
126         self.inner.with_lock(|stream| stream.poll_write(cx, buf))
127     }
128 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>129     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
130         self.inner.with_lock(|stream| stream.poll_flush(cx))
131     }
132 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>133     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
134         self.inner.with_lock(|stream| stream.poll_shutdown(cx))
135     }
136 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>137     fn poll_write_vectored(
138         self: Pin<&mut Self>,
139         cx: &mut Context<'_>,
140         bufs: &[io::IoSlice<'_>],
141     ) -> Poll<Result<usize, io::Error>> {
142         self.inner
143             .with_lock(|stream| stream.poll_write_vectored(cx, bufs))
144     }
145 
is_write_vectored(&self) -> bool146     fn is_write_vectored(&self) -> bool {
147         self.inner.is_write_vectored
148     }
149 }
150 
151 unsafe impl<T: Send> Send for ReadHalf<T> {}
152 unsafe impl<T: Send> Send for WriteHalf<T> {}
153 unsafe impl<T: Sync> Sync for ReadHalf<T> {}
154 unsafe impl<T: Sync> Sync for WriteHalf<T> {}
155 
156 impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result157     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
158         fmt.debug_struct("split::ReadHalf").finish()
159     }
160 }
161 
162 impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result163     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
164         fmt.debug_struct("split::WriteHalf").finish()
165     }
166 }
167