use std::{
    future::poll_fn,
    io::IoSlice,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_util::io::{InspectReader, InspectWriter};

/// An AsyncRead implementation that works byte-by-byte, to catch out callers
/// who don't allow for `buf` being part-filled before the call
struct SmallReader {
    contents: Vec<u8>,
}

impl Unpin for SmallReader {}

impl AsyncRead for SmallReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
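        // Yield at most one byte per call, popped from the end of `contents`
        // (so the stream comes out reversed, which these tests don't rely on).
        // Once `contents` is empty, returning Ready with nothing added to
        // `buf` signals EOF.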
        if let Some(byte) = self.contents.pop() {
            buf.put_slice(&[byte])
        }
        Poll::Ready(Ok(()))
    }
}

#[tokio::test]
async fn read_tee() {
    let contents = b"This could be really long, you know".to_vec();
    let reader = SmallReader {
        contents: contents.clone(),
    };
    let mut altout: Vec<u8> = Vec::new();
    let mut teeout = Vec::new();
    {
        let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
        tee.read_to_end(&mut teeout).await.unwrap();
    }
    assert_eq!(teeout, altout);
    assert_eq!(altout.len(), contents.len());
}

/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
/// that reads the whole of the first buffer plus one byte from the second in
/// poll_write_vectored.
///
/// This is designed to catch bugs in handling partially written buffers
#[derive(Debug)]
struct SmallWriter {
    contents: Vec<u8>,
}

impl Unpin for SmallWriter {}

impl AsyncWrite for SmallWriter {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Just write one byte at a time
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        self.contents.push(buf[0]);
        Poll::Ready(Ok(1))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        // Write all of the first buffer, then one byte from the second buffer
        // This should trip up anything that doesn't correctly handle multiple
        // buffers.
        if bufs.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let mut written_len = bufs[0].len();
        self.contents.extend_from_slice(&bufs[0]);

        if bufs.len() > 1 {
            let buf = bufs[1];
            if !buf.is_empty() {
                written_len += 1;
                self.contents.push(buf[0]);
            }
        }
        Poll::Ready(Ok(written_len))
    }

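    // Advertise vectored-write support; the write_tee_vectored test below
    // checks that InspectWriter reports this for the wrapped writer too.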
    fn is_write_vectored(&self) -> bool {
        true
    }
}

#[tokio::test]
async fn write_tee() {
    let mut altout: Vec<u8> = Vec::new();
    let mut writeout = SmallWriter {
        contents: Vec::new(),
    };
    {
        let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
        tee.write_all(b"A testing string, very testing")
            .await
            .unwrap();
    }
    assert_eq!(altout, writeout.contents);
}

// This is inefficient, but works well enough for test use.
// If you want something similar for real code, you'll want to avoid all the
// fun of manipulating `bufs` - ideally, by the time you read this,
// IoSlice::advance_slices will be stable, and you can use that.
async fn write_all_vectored<W: AsyncWrite + Unpin>(
    mut writer: W,
    mut bufs: Vec<Vec<u8>>,
) -> Result<usize, std::io::Error> {
    let mut res = 0;
    while !bufs.is_empty() {
        let mut written = poll_fn(|cx| {
            let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
            Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
        })
        .await?;
        res += written;
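        // Drop every buffer that was written in full, and trim the leading
        // bytes of a buffer that was only partially written, so the next
        // poll starts from the first unwritten byte.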
        while written > 0 {
            let buf_len = bufs[0].len();
            if buf_len <= written {
                bufs.remove(0);
                written -= buf_len;
            } else {
                let buf = &mut bufs[0];
                let drain_len = written.min(buf.len());
                buf.drain(..drain_len);
                written -= drain_len;
            }
        }
    }
    Ok(res)
}

#[tokio::test]
async fn write_tee_vectored() {
    let mut altout: Vec<u8> = Vec::new();
    let mut writeout = SmallWriter {
        contents: Vec::new(),
    };
    let original = b"A very long string split up";
    let bufs: Vec<Vec<u8>> = original
        .split(|b| b.is_ascii_whitespace())
        .map(Vec::from)
        .collect();
    assert!(bufs.len() > 1);
    let expected: Vec<u8> = {
        let mut out = Vec::new();
        for item in &bufs {
            out.extend_from_slice(item)
        }
        out
    };
    {
        let mut bufcount = 0;
        let tee = InspectWriter::new(&mut writeout, |bytes| {
            bufcount += 1;
            altout.extend(bytes)
        });

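        // SmallWriter advertises vectored-write support, and InspectWriter
        // should pass that through, so write_all_vectored really exercises
        // poll_write_vectored rather than falling back to poll_write.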
        assert!(tee.is_write_vectored());

        write_all_vectored(tee, bufs.clone()).await.unwrap();

        assert!(bufcount >= bufs.len());
    }
    assert_eq!(altout, writeout.contents);
    assert_eq!(writeout.contents, expected);
}