1 use std::{
2     future::poll_fn,
3     io::IoSlice,
4     pin::Pin,
5     task::{Context, Poll},
6 };
7 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
8 use tokio_util::io::{InspectReader, InspectWriter};
9 
10 /// An AsyncRead implementation that works byte-by-byte, to catch out callers
11 /// who don't allow for `buf` being part-filled before the call
12 struct SmallReader {
13     contents: Vec<u8>,
14 }
15 
16 impl Unpin for SmallReader {}
17 
18 impl AsyncRead for SmallReader {
poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>19     fn poll_read(
20         mut self: Pin<&mut Self>,
21         _cx: &mut Context<'_>,
22         buf: &mut ReadBuf<'_>,
23     ) -> Poll<std::io::Result<()>> {
24         if let Some(byte) = self.contents.pop() {
25             buf.put_slice(&[byte])
26         }
27         Poll::Ready(Ok(()))
28     }
29 }
30 
31 #[tokio::test]
read_tee()32 async fn read_tee() {
33     let contents = b"This could be really long, you know".to_vec();
34     let reader = SmallReader {
35         contents: contents.clone(),
36     };
37     let mut altout: Vec<u8> = Vec::new();
38     let mut teeout = Vec::new();
39     {
40         let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
41         tee.read_to_end(&mut teeout).await.unwrap();
42     }
43     assert_eq!(teeout, altout);
44     assert_eq!(altout.len(), contents.len());
45 }
46 
47 /// An AsyncWrite implementation that works byte-by-byte for poll_write, and
48 /// that reads the whole of the first buffer plus one byte from the second in
49 /// poll_write_vectored.
50 ///
51 /// This is designed to catch bugs in handling partially written buffers
52 #[derive(Debug)]
53 struct SmallWriter {
54     contents: Vec<u8>,
55 }
56 
57 impl Unpin for SmallWriter {}
58 
59 impl AsyncWrite for SmallWriter {
poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, std::io::Error>>60     fn poll_write(
61         mut self: Pin<&mut Self>,
62         _cx: &mut Context<'_>,
63         buf: &[u8],
64     ) -> Poll<Result<usize, std::io::Error>> {
65         // Just write one byte at a time
66         if buf.is_empty() {
67             return Poll::Ready(Ok(0));
68         }
69         self.contents.push(buf[0]);
70         Poll::Ready(Ok(1))
71     }
72 
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>>73     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
74         Poll::Ready(Ok(()))
75     }
76 
poll_shutdown( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<(), std::io::Error>>77     fn poll_shutdown(
78         self: Pin<&mut Self>,
79         _cx: &mut Context<'_>,
80     ) -> Poll<Result<(), std::io::Error>> {
81         Poll::Ready(Ok(()))
82     }
83 
poll_write_vectored( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, std::io::Error>>84     fn poll_write_vectored(
85         mut self: Pin<&mut Self>,
86         _cx: &mut Context<'_>,
87         bufs: &[IoSlice<'_>],
88     ) -> Poll<Result<usize, std::io::Error>> {
89         // Write all of the first buffer, then one byte from the second buffer
90         // This should trip up anything that doesn't correctly handle multiple
91         // buffers.
92         if bufs.is_empty() {
93             return Poll::Ready(Ok(0));
94         }
95         let mut written_len = bufs[0].len();
96         self.contents.extend_from_slice(&bufs[0]);
97 
98         if bufs.len() > 1 {
99             let buf = bufs[1];
100             if !buf.is_empty() {
101                 written_len += 1;
102                 self.contents.push(buf[0]);
103             }
104         }
105         Poll::Ready(Ok(written_len))
106     }
107 
is_write_vectored(&self) -> bool108     fn is_write_vectored(&self) -> bool {
109         true
110     }
111 }
112 
113 #[tokio::test]
write_tee()114 async fn write_tee() {
115     let mut altout: Vec<u8> = Vec::new();
116     let mut writeout = SmallWriter {
117         contents: Vec::new(),
118     };
119     {
120         let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
121         tee.write_all(b"A testing string, very testing")
122             .await
123             .unwrap();
124     }
125     assert_eq!(altout, writeout.contents);
126 }
127 
128 // This is inefficient, but works well enough for test use.
129 // If you want something similar for real code, you'll want to avoid all the
130 // fun of manipulating `bufs` - ideally, by the time you read this,
131 // IoSlice::advance_slices will be stable, and you can use that.
write_all_vectored<W: AsyncWrite + Unpin>( mut writer: W, mut bufs: Vec<Vec<u8>>, ) -> Result<usize, std::io::Error>132 async fn write_all_vectored<W: AsyncWrite + Unpin>(
133     mut writer: W,
134     mut bufs: Vec<Vec<u8>>,
135 ) -> Result<usize, std::io::Error> {
136     let mut res = 0;
137     while !bufs.is_empty() {
138         let mut written = poll_fn(|cx| {
139             let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
140             Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
141         })
142         .await?;
143         res += written;
144         while written > 0 {
145             let buf_len = bufs[0].len();
146             if buf_len <= written {
147                 bufs.remove(0);
148                 written -= buf_len;
149             } else {
150                 let buf = &mut bufs[0];
151                 let drain_len = written.min(buf.len());
152                 buf.drain(..drain_len);
153                 written -= drain_len;
154             }
155         }
156     }
157     Ok(res)
158 }
159 
160 #[tokio::test]
write_tee_vectored()161 async fn write_tee_vectored() {
162     let mut altout: Vec<u8> = Vec::new();
163     let mut writeout = SmallWriter {
164         contents: Vec::new(),
165     };
166     let original = b"A very long string split up";
167     let bufs: Vec<Vec<u8>> = original
168         .split(|b| b.is_ascii_whitespace())
169         .map(Vec::from)
170         .collect();
171     assert!(bufs.len() > 1);
172     let expected: Vec<u8> = {
173         let mut out = Vec::new();
174         for item in &bufs {
175             out.extend_from_slice(item)
176         }
177         out
178     };
179     {
180         let mut bufcount = 0;
181         let tee = InspectWriter::new(&mut writeout, |bytes| {
182             bufcount += 1;
183             altout.extend(bytes)
184         });
185 
186         assert!(tee.is_write_vectored());
187 
188         write_all_vectored(tee, bufs.clone()).await.unwrap();
189 
190         assert!(bufcount >= bufs.len());
191     }
192     assert_eq!(altout, writeout.contents);
193     assert_eq!(writeout.contents, expected);
194 }
195