1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use bytes::BytesMut;
5 use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
6 use tokio_test::assert_ok;
7 
8 use std::pin::Pin;
9 use std::task::{ready, Context, Poll};
10 
11 #[tokio::test]
copy()12 async fn copy() {
13     struct Rd(bool);
14 
15     impl AsyncRead for Rd {
16         fn poll_read(
17             mut self: Pin<&mut Self>,
18             _cx: &mut Context<'_>,
19             buf: &mut ReadBuf<'_>,
20         ) -> Poll<io::Result<()>> {
21             if self.0 {
22                 buf.put_slice(b"hello world");
23                 self.0 = false;
24                 Poll::Ready(Ok(()))
25             } else {
26                 Poll::Ready(Ok(()))
27             }
28         }
29     }
30 
31     let mut rd = Rd(true);
32     let mut wr = Vec::new();
33 
34     let n = assert_ok!(io::copy(&mut rd, &mut wr).await);
35     assert_eq!(n, 11);
36     assert_eq!(wr, b"hello world");
37 }
38 
39 #[tokio::test]
proxy()40 async fn proxy() {
41     struct BufferedWd {
42         buf: BytesMut,
43         writer: io::DuplexStream,
44     }
45 
46     impl AsyncWrite for BufferedWd {
47         fn poll_write(
48             self: Pin<&mut Self>,
49             _cx: &mut Context<'_>,
50             buf: &[u8],
51         ) -> Poll<io::Result<usize>> {
52             self.get_mut().buf.extend_from_slice(buf);
53             Poll::Ready(Ok(buf.len()))
54         }
55 
56         fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
57             let this = self.get_mut();
58 
59             while !this.buf.is_empty() {
60                 let n = ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buf))?;
61                 let _ = this.buf.split_to(n);
62             }
63 
64             Pin::new(&mut this.writer).poll_flush(cx)
65         }
66 
67         fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68             Pin::new(&mut self.writer).poll_shutdown(cx)
69         }
70     }
71 
72     let (rd, wd) = io::duplex(1024);
73     let mut rd = rd.take(1024);
74     let mut wd = BufferedWd {
75         buf: BytesMut::new(),
76         writer: wd,
77     };
78 
79     // write start bytes
80     assert_ok!(wd.write_all(&[0x42; 512]).await);
81     assert_ok!(wd.flush().await);
82 
83     let n = assert_ok!(io::copy(&mut rd, &mut wd).await);
84 
85     assert_eq!(n, 1024);
86 }
87 
88 #[tokio::test]
copy_is_cooperative()89 async fn copy_is_cooperative() {
90     tokio::select! {
91         biased;
92         _ = async {
93             loop {
94                 let mut reader: &[u8] = b"hello";
95                 let mut writer: Vec<u8> = vec![];
96                 let _ = io::copy(&mut reader, &mut writer).await;
97             }
98         } => {},
99         _ = tokio::task::yield_now() => {}
100     }
101 }
102