1 #![warn(rust_2018_idioms)]
2 #![cfg(all(feature = "full", not(target_os = "wasi")))] // Wasi doesn't support bind
3 
4 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::try_join;
7 use tokio_test::task;
8 use tokio_test::{assert_ok, assert_pending, assert_ready_ok};
9 
10 use std::future::poll_fn;
11 use std::io;
12 use std::task::Poll;
13 use std::time::Duration;
14 
15 #[tokio::test]
16 #[cfg_attr(miri, ignore)] // No `socket` on miri.
set_linger()17 async fn set_linger() {
18     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
19 
20     let stream = TcpStream::connect(listener.local_addr().unwrap())
21         .await
22         .unwrap();
23 
24     assert_ok!(stream.set_linger(Some(Duration::from_secs(1))));
25     assert_eq!(stream.linger().unwrap().unwrap().as_secs(), 1);
26 
27     assert_ok!(stream.set_linger(None));
28     assert!(stream.linger().unwrap().is_none());
29 }
30 
31 #[tokio::test]
32 #[cfg_attr(miri, ignore)] // No `socket` on miri.
try_read_write()33 async fn try_read_write() {
34     const DATA: &[u8] = b"this is some data to write to the socket";
35 
36     // Create listener
37     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
38 
39     // Create socket pair
40     let client = TcpStream::connect(listener.local_addr().unwrap())
41         .await
42         .unwrap();
43     let (server, _) = listener.accept().await.unwrap();
44     let mut written = DATA.to_vec();
45 
46     // Track the server receiving data
47     let mut readable = task::spawn(server.readable());
48     assert_pending!(readable.poll());
49 
50     // Write data.
51     client.writable().await.unwrap();
52     assert_eq!(DATA.len(), client.try_write(DATA).unwrap());
53 
54     // The task should be notified
55     while !readable.is_woken() {
56         tokio::task::yield_now().await;
57     }
58 
59     // Fill the write buffer using non-vectored I/O
60     loop {
61         // Still ready
62         let mut writable = task::spawn(client.writable());
63         assert_ready_ok!(writable.poll());
64 
65         match client.try_write(DATA) {
66             Ok(n) => written.extend(&DATA[..n]),
67             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
68                 break;
69             }
70             Err(e) => panic!("error = {e:?}"),
71         }
72     }
73 
74     {
75         // Write buffer full
76         let mut writable = task::spawn(client.writable());
77         assert_pending!(writable.poll());
78 
79         // Drain the socket from the server end using non-vectored I/O
80         let mut read = vec![0; written.len()];
81         let mut i = 0;
82 
83         while i < read.len() {
84             server.readable().await.unwrap();
85 
86             match server.try_read(&mut read[i..]) {
87                 Ok(n) => i += n,
88                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
89                 Err(e) => panic!("error = {e:?}"),
90             }
91         }
92 
93         assert_eq!(read, written);
94     }
95 
96     written.clear();
97     client.writable().await.unwrap();
98 
99     // Fill the write buffer using vectored I/O
100     let data_bufs: Vec<_> = DATA.chunks(10).map(io::IoSlice::new).collect();
101     loop {
102         // Still ready
103         let mut writable = task::spawn(client.writable());
104         assert_ready_ok!(writable.poll());
105 
106         match client.try_write_vectored(&data_bufs) {
107             Ok(n) => written.extend(&DATA[..n]),
108             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
109                 break;
110             }
111             Err(e) => panic!("error = {e:?}"),
112         }
113     }
114 
115     {
116         // Write buffer full
117         let mut writable = task::spawn(client.writable());
118         assert_pending!(writable.poll());
119 
120         // Drain the socket from the server end using vectored I/O
121         let mut read = vec![0; written.len()];
122         let mut i = 0;
123 
124         while i < read.len() {
125             server.readable().await.unwrap();
126 
127             let mut bufs: Vec<_> = read[i..]
128                 .chunks_mut(0x10000)
129                 .map(io::IoSliceMut::new)
130                 .collect();
131             match server.try_read_vectored(&mut bufs) {
132                 Ok(n) => i += n,
133                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
134                 Err(e) => panic!("error = {e:?}"),
135             }
136         }
137 
138         assert_eq!(read, written);
139     }
140 
141     // Now, we listen for shutdown
142     drop(client);
143 
144     loop {
145         let ready = server.ready(Interest::READABLE).await.unwrap();
146 
147         if ready.is_read_closed() {
148             return;
149         } else {
150             tokio::task::yield_now().await;
151         }
152     }
153 }
154 
#[test]
fn buffer_not_included_in_future() {
    use std::mem;

    const N: usize = 4096;

    // The future is only constructed, never polled; the test inspects its
    // size alone. The N-byte buffer is created after the `.await` in each
    // iteration and is not live across one, and the assertion below checks
    // that it does not end up in the future's stored state.
    let fut = async {
        let stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();

        loop {
            stream.readable().await.unwrap();

            let mut chunk = [0; N];
            if stream.try_read(&mut chunk[..]).unwrap() == 0 {
                break;
            }
        }
    };

    assert!(mem::size_of_val(&fut) < 1000);
}
179 
// Drives `poll_read_ready` on `$stream` until it resolves, then asserts the
// readiness result is `Ok`. Must be used inside an async context.
macro_rules! assert_readable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_read_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
185 
// Polls `poll_read_ready` on `$stream` exactly once (the wrapper future
// completes on its first poll) and asserts that poll returned `Pending`.
macro_rules! assert_not_readable_by_polling {
    ($stream:expr) => {{
        let probe = poll_fn(|cx| {
            assert_pending!($stream.poll_read_ready(cx));
            Poll::Ready(())
        });
        probe.await;
    }};
}
195 
// Drives `poll_write_ready` on `$stream` until it resolves, then asserts the
// readiness result is `Ok`. Must be used inside an async context.
macro_rules! assert_writable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_write_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
201 
// Polls `poll_write_ready` on `$stream` exactly once (the wrapper future
// completes on its first poll) and asserts that poll returned `Pending`.
macro_rules! assert_not_writable_by_polling {
    ($stream:expr) => {{
        let probe = poll_fn(|cx| {
            assert_pending!($stream.poll_write_ready(cx));
            Poll::Ready(())
        });
        probe.await;
    }};
}
211 
212 #[tokio::test]
213 #[cfg_attr(miri, ignore)] // No `socket` on miri.
poll_read_ready()214 async fn poll_read_ready() {
215     let (mut client, mut server) = create_pair().await;
216 
217     // Initial state - not readable.
218     assert_not_readable_by_polling!(server);
219 
220     // There is data in the buffer - readable.
221     assert_ok!(client.write_all(b"ping").await);
222     assert_readable_by_polling!(server);
223 
224     // Readable until calls to `poll_read` return `Poll::Pending`.
225     let mut buf = [0u8; 4];
226     assert_ok!(server.read_exact(&mut buf).await);
227     assert_readable_by_polling!(server);
228     read_until_pending(&mut server);
229     assert_not_readable_by_polling!(server);
230 
231     // Detect the client disconnect.
232     drop(client);
233     assert_readable_by_polling!(server);
234 }
235 
236 #[tokio::test]
237 #[cfg_attr(miri, ignore)] // No `socket` on miri.
poll_write_ready()238 async fn poll_write_ready() {
239     let (mut client, server) = create_pair().await;
240 
241     // Initial state - writable.
242     assert_writable_by_polling!(client);
243 
244     // No space to write - not writable.
245     write_until_pending(&mut client);
246     assert_not_writable_by_polling!(client);
247 
248     // Detect the server disconnect.
249     drop(server);
250     assert_writable_by_polling!(client);
251 }
252 
create_pair() -> (TcpStream, TcpStream)253 async fn create_pair() -> (TcpStream, TcpStream) {
254     let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
255     let addr = assert_ok!(listener.local_addr());
256     let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
257     (client, server)
258 }
259 
read_until_pending(stream: &mut TcpStream) -> usize260 fn read_until_pending(stream: &mut TcpStream) -> usize {
261     let mut buf = vec![0u8; 1024 * 1024];
262     let mut total = 0;
263     loop {
264         match stream.try_read(&mut buf) {
265             Ok(n) => total += n,
266             Err(err) => {
267                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
268                 break;
269             }
270         }
271     }
272     total
273 }
274 
write_until_pending(stream: &mut TcpStream) -> usize275 fn write_until_pending(stream: &mut TcpStream) -> usize {
276     let buf = vec![0u8; 1024 * 1024];
277     let mut total = 0;
278     loop {
279         match stream.try_write(&buf) {
280             Ok(n) => total += n,
281             Err(err) => {
282                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
283                 break;
284             }
285         }
286     }
287     total
288 }
289 
290 #[tokio::test]
291 #[cfg_attr(miri, ignore)] // No `socket` on miri.
try_read_buf()292 async fn try_read_buf() {
293     const DATA: &[u8] = b"this is some data to write to the socket";
294 
295     // Create listener
296     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
297 
298     // Create socket pair
299     let client = TcpStream::connect(listener.local_addr().unwrap())
300         .await
301         .unwrap();
302     let (server, _) = listener.accept().await.unwrap();
303     let mut written = DATA.to_vec();
304 
305     // Track the server receiving data
306     let mut readable = task::spawn(server.readable());
307     assert_pending!(readable.poll());
308 
309     // Write data.
310     client.writable().await.unwrap();
311     assert_eq!(DATA.len(), client.try_write(DATA).unwrap());
312 
313     // The task should be notified
314     while !readable.is_woken() {
315         tokio::task::yield_now().await;
316     }
317 
318     // Fill the write buffer
319     loop {
320         // Still ready
321         let mut writable = task::spawn(client.writable());
322         assert_ready_ok!(writable.poll());
323 
324         match client.try_write(DATA) {
325             Ok(n) => written.extend(&DATA[..n]),
326             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
327                 break;
328             }
329             Err(e) => panic!("error = {e:?}"),
330         }
331     }
332 
333     {
334         // Write buffer full
335         let mut writable = task::spawn(client.writable());
336         assert_pending!(writable.poll());
337 
338         // Drain the socket from the server end
339         let mut read = Vec::with_capacity(written.len());
340         let mut i = 0;
341 
342         while i < read.capacity() {
343             server.readable().await.unwrap();
344 
345             match server.try_read_buf(&mut read) {
346                 Ok(n) => i += n,
347                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
348                 Err(e) => panic!("error = {e:?}"),
349             }
350         }
351 
352         assert_eq!(read, written);
353     }
354 
355     // Now, we listen for shutdown
356     drop(client);
357 
358     loop {
359         let ready = server.ready(Interest::READABLE).await.unwrap();
360 
361         if ready.is_read_closed() {
362             return;
363         } else {
364             tokio::task::yield_now().await;
365         }
366     }
367 }
368 
369 // read_closed is a best effort event, so test only for no false positives.
370 #[tokio::test]
371 #[cfg_attr(miri, ignore)] // No `socket` on miri.
read_closed()372 async fn read_closed() {
373     let (client, mut server) = create_pair().await;
374 
375     let mut ready_fut = task::spawn(client.ready(Interest::READABLE));
376     assert_pending!(ready_fut.poll());
377 
378     assert_ok!(server.write_all(b"ping").await);
379 
380     let ready_event = assert_ok!(ready_fut.await);
381 
382     assert!(!ready_event.is_read_closed());
383 }
384 
385 // write_closed is a best effort event, so test only for no false positives.
386 #[tokio::test]
387 #[cfg_attr(miri, ignore)] // No `socket` on miri.
write_closed()388 async fn write_closed() {
389     let (mut client, mut server) = create_pair().await;
390 
391     // Fill the write buffer.
392     let write_size = write_until_pending(&mut client);
393     let mut ready_fut = task::spawn(client.ready(Interest::WRITABLE));
394     assert_pending!(ready_fut.poll());
395 
396     // Drain the socket to make client writable.
397     let mut read_size = 0;
398     while read_size < write_size {
399         server.readable().await.unwrap();
400         read_size += read_until_pending(&mut server);
401     }
402 
403     let ready_event = assert_ok!(ready_fut.await);
404 
405     assert!(!ready_event.is_write_closed());
406 }
407