1 //! Compatibility between the `tokio::io` and `futures-io` versions of the
2 //! `AsyncRead` and `AsyncWrite` traits.
3 use pin_project_lite::pin_project;
4 use std::io;
5 use std::pin::Pin;
6 use std::task::{ready, Context, Poll};
7 
pin_project! {
    /// A compatibility layer that allows conversion between the
    /// `tokio::io` and `futures-io` `AsyncRead` and `AsyncWrite` traits.
    ///
    /// The same wrapper is used in both directions, and it also bridges the
    /// `AsyncBufRead` and `AsyncSeek` traits: which traits a `Compat<T>`
    /// implements depends on which traits `T` implements.
    #[derive(Copy, Clone, Debug)]
    pub struct Compat<T> {
        // The wrapped I/O object. It is structurally pinned so that its
        // `Pin<&mut Self>` poll methods can be reached through `project()`.
        #[pin]
        inner: T,
        // Seek bookkeeping kept by the wrapper; `None` when nothing is recorded.
        // - Wrapping a tokio type (futures `poll_seek`): the target already
        //   submitted via tokio `start_seek` whose completion is still being
        //   polled, so re-polling with the same target does not resubmit it.
        // - Wrapping a futures type (tokio `start_seek`): the target stored by
        //   `start_seek` and issued by the next `poll_complete`.
        seek_pos: Option<io::SeekFrom>,
    }
}
18 
19 /// Extension trait that allows converting a type implementing
20 /// `futures_io::AsyncRead` to implement `tokio::io::AsyncRead`.
21 pub trait FuturesAsyncReadCompatExt: futures_io::AsyncRead {
22     /// Wraps `self` with a compatibility layer that implements
23     /// `tokio_io::AsyncRead`.
compat(self) -> Compat<Self> where Self: Sized,24     fn compat(self) -> Compat<Self>
25     where
26         Self: Sized,
27     {
28         Compat::new(self)
29     }
30 }
31 
32 impl<T: futures_io::AsyncRead> FuturesAsyncReadCompatExt for T {}
33 
34 /// Extension trait that allows converting a type implementing
35 /// `futures_io::AsyncWrite` to implement `tokio::io::AsyncWrite`.
36 pub trait FuturesAsyncWriteCompatExt: futures_io::AsyncWrite {
37     /// Wraps `self` with a compatibility layer that implements
38     /// `tokio::io::AsyncWrite`.
compat_write(self) -> Compat<Self> where Self: Sized,39     fn compat_write(self) -> Compat<Self>
40     where
41         Self: Sized,
42     {
43         Compat::new(self)
44     }
45 }
46 
47 impl<T: futures_io::AsyncWrite> FuturesAsyncWriteCompatExt for T {}
48 
49 /// Extension trait that allows converting a type implementing
50 /// `tokio::io::AsyncRead` to implement `futures_io::AsyncRead`.
51 pub trait TokioAsyncReadCompatExt: tokio::io::AsyncRead {
52     /// Wraps `self` with a compatibility layer that implements
53     /// `futures_io::AsyncRead`.
compat(self) -> Compat<Self> where Self: Sized,54     fn compat(self) -> Compat<Self>
55     where
56         Self: Sized,
57     {
58         Compat::new(self)
59     }
60 }
61 
62 impl<T: tokio::io::AsyncRead> TokioAsyncReadCompatExt for T {}
63 
64 /// Extension trait that allows converting a type implementing
65 /// `tokio::io::AsyncWrite` to implement `futures_io::AsyncWrite`.
66 pub trait TokioAsyncWriteCompatExt: tokio::io::AsyncWrite {
67     /// Wraps `self` with a compatibility layer that implements
68     /// `futures_io::AsyncWrite`.
compat_write(self) -> Compat<Self> where Self: Sized,69     fn compat_write(self) -> Compat<Self>
70     where
71         Self: Sized,
72     {
73         Compat::new(self)
74     }
75 }
76 
77 impl<T: tokio::io::AsyncWrite> TokioAsyncWriteCompatExt for T {}
78 
79 // === impl Compat ===
80 
81 impl<T> Compat<T> {
new(inner: T) -> Self82     fn new(inner: T) -> Self {
83         Self {
84             inner,
85             seek_pos: None,
86         }
87     }
88 
89     /// Get a reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
90     /// contained within.
get_ref(&self) -> &T91     pub fn get_ref(&self) -> &T {
92         &self.inner
93     }
94 
95     /// Get a mutable reference to the `Future`, `Stream`, `AsyncRead`, or `AsyncWrite` object
96     /// contained within.
get_mut(&mut self) -> &mut T97     pub fn get_mut(&mut self) -> &mut T {
98         &mut self.inner
99     }
100 
101     /// Returns the wrapped item.
into_inner(self) -> T102     pub fn into_inner(self) -> T {
103         self.inner
104     }
105 }
106 
107 impl<T> tokio::io::AsyncRead for Compat<T>
108 where
109     T: futures_io::AsyncRead,
110 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<()>>111     fn poll_read(
112         self: Pin<&mut Self>,
113         cx: &mut Context<'_>,
114         buf: &mut tokio::io::ReadBuf<'_>,
115     ) -> Poll<io::Result<()>> {
116         // We can't trust the inner type to not peak at the bytes,
117         // so we must defensively initialize the buffer.
118         let slice = buf.initialize_unfilled();
119         let n = ready!(futures_io::AsyncRead::poll_read(
120             self.project().inner,
121             cx,
122             slice
123         ))?;
124         buf.advance(n);
125         Poll::Ready(Ok(()))
126     }
127 }
128 
129 impl<T> futures_io::AsyncRead for Compat<T>
130 where
131     T: tokio::io::AsyncRead,
132 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, slice: &mut [u8], ) -> Poll<io::Result<usize>>133     fn poll_read(
134         self: Pin<&mut Self>,
135         cx: &mut Context<'_>,
136         slice: &mut [u8],
137     ) -> Poll<io::Result<usize>> {
138         let mut buf = tokio::io::ReadBuf::new(slice);
139         ready!(tokio::io::AsyncRead::poll_read(
140             self.project().inner,
141             cx,
142             &mut buf
143         ))?;
144         Poll::Ready(Ok(buf.filled().len()))
145     }
146 }
147 
148 impl<T> tokio::io::AsyncBufRead for Compat<T>
149 where
150     T: futures_io::AsyncBufRead,
151 {
poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&'a [u8]>>152     fn poll_fill_buf<'a>(
153         self: Pin<&'a mut Self>,
154         cx: &mut Context<'_>,
155     ) -> Poll<io::Result<&'a [u8]>> {
156         futures_io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
157     }
158 
consume(self: Pin<&mut Self>, amt: usize)159     fn consume(self: Pin<&mut Self>, amt: usize) {
160         futures_io::AsyncBufRead::consume(self.project().inner, amt)
161     }
162 }
163 
164 impl<T> futures_io::AsyncBufRead for Compat<T>
165 where
166     T: tokio::io::AsyncBufRead,
167 {
poll_fill_buf<'a>( self: Pin<&'a mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&'a [u8]>>168     fn poll_fill_buf<'a>(
169         self: Pin<&'a mut Self>,
170         cx: &mut Context<'_>,
171     ) -> Poll<io::Result<&'a [u8]>> {
172         tokio::io::AsyncBufRead::poll_fill_buf(self.project().inner, cx)
173     }
174 
consume(self: Pin<&mut Self>, amt: usize)175     fn consume(self: Pin<&mut Self>, amt: usize) {
176         tokio::io::AsyncBufRead::consume(self.project().inner, amt)
177     }
178 }
179 
180 impl<T> tokio::io::AsyncWrite for Compat<T>
181 where
182     T: futures_io::AsyncWrite,
183 {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>184     fn poll_write(
185         self: Pin<&mut Self>,
186         cx: &mut Context<'_>,
187         buf: &[u8],
188     ) -> Poll<io::Result<usize>> {
189         futures_io::AsyncWrite::poll_write(self.project().inner, cx, buf)
190     }
191 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>192     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193         futures_io::AsyncWrite::poll_flush(self.project().inner, cx)
194     }
195 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>196     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197         futures_io::AsyncWrite::poll_close(self.project().inner, cx)
198     }
199 }
200 
201 impl<T> futures_io::AsyncWrite for Compat<T>
202 where
203     T: tokio::io::AsyncWrite,
204 {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>205     fn poll_write(
206         self: Pin<&mut Self>,
207         cx: &mut Context<'_>,
208         buf: &[u8],
209     ) -> Poll<io::Result<usize>> {
210         tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
211     }
212 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>213     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
214         tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
215     }
216 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>217     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
218         tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
219     }
220 }
221 
222 impl<T: tokio::io::AsyncSeek> futures_io::AsyncSeek for Compat<T> {
poll_seek( mut self: Pin<&mut Self>, cx: &mut Context<'_>, pos: io::SeekFrom, ) -> Poll<io::Result<u64>>223     fn poll_seek(
224         mut self: Pin<&mut Self>,
225         cx: &mut Context<'_>,
226         pos: io::SeekFrom,
227     ) -> Poll<io::Result<u64>> {
228         if self.seek_pos != Some(pos) {
229             // Ensure previous seeks have finished before starting a new one
230             ready!(self.as_mut().project().inner.poll_complete(cx))?;
231             self.as_mut().project().inner.start_seek(pos)?;
232             *self.as_mut().project().seek_pos = Some(pos);
233         }
234         let res = ready!(self.as_mut().project().inner.poll_complete(cx));
235         *self.as_mut().project().seek_pos = None;
236         Poll::Ready(res)
237     }
238 }
239 
240 impl<T: futures_io::AsyncSeek> tokio::io::AsyncSeek for Compat<T> {
start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()>241     fn start_seek(mut self: Pin<&mut Self>, pos: io::SeekFrom) -> io::Result<()> {
242         *self.as_mut().project().seek_pos = Some(pos);
243         Ok(())
244     }
245 
poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>246     fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
247         let pos = match self.seek_pos {
248             None => {
249                 // tokio 1.x AsyncSeek recommends calling poll_complete before start_seek.
250                 // We don't have to guarantee that the value returned by
251                 // poll_complete called without start_seek is correct,
252                 // so we'll return 0.
253                 return Poll::Ready(Ok(0));
254             }
255             Some(pos) => pos,
256         };
257         let res = ready!(self.as_mut().project().inner.poll_seek(cx, pos));
258         *self.as_mut().project().seek_pos = None;
259         Poll::Ready(res)
260     }
261 }
262 
263 #[cfg(unix)]
264 impl<T: std::os::unix::io::AsRawFd> std::os::unix::io::AsRawFd for Compat<T> {
as_raw_fd(&self) -> std::os::unix::io::RawFd265     fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
266         self.inner.as_raw_fd()
267     }
268 }
269 
#[cfg(windows)]
impl<T> std::os::windows::io::AsRawHandle for Compat<T>
where
    T: std::os::windows::io::AsRawHandle,
{
    /// Returns the raw handle of the wrapped object.
    fn as_raw_handle(&self) -> std::os::windows::io::RawHandle {
        std::os::windows::io::AsRawHandle::as_raw_handle(&self.inner)
    }
}
276