1 #![cfg(feature = "full")]
2 #![warn(rust_2018_idioms)]
3 #![cfg(unix)]
4 #![cfg(not(miri))] // No socket in miri.
5 
6 use std::io;
7 #[cfg(target_os = "android")]
8 use std::os::android::net::SocketAddrExt;
9 #[cfg(target_os = "linux")]
10 use std::os::linux::net::SocketAddrExt;
11 use std::task::Poll;
12 
13 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
14 use tokio::net::{UnixListener, UnixStream};
15 use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
16 
17 use futures::future::{poll_fn, try_join};
18 
19 #[tokio::test]
accept_read_write() -> std::io::Result<()>20 async fn accept_read_write() -> std::io::Result<()> {
21     let dir = tempfile::Builder::new()
22         .prefix("tokio-uds-tests")
23         .tempdir()
24         .unwrap();
25     let sock_path = dir.path().join("connect.sock");
26 
27     let listener = UnixListener::bind(&sock_path)?;
28 
29     let accept = listener.accept();
30     let connect = UnixStream::connect(&sock_path);
31     let ((mut server, _), mut client) = try_join(accept, connect).await?;
32 
33     // Write to the client.
34     client.write_all(b"hello").await?;
35     drop(client);
36 
37     // Read from the server.
38     let mut buf = vec![];
39     server.read_to_end(&mut buf).await?;
40     assert_eq!(&buf, b"hello");
41     let len = server.read(&mut buf).await?;
42     assert_eq!(len, 0);
43     Ok(())
44 }
45 
46 #[tokio::test]
shutdown() -> std::io::Result<()>47 async fn shutdown() -> std::io::Result<()> {
48     let dir = tempfile::Builder::new()
49         .prefix("tokio-uds-tests")
50         .tempdir()
51         .unwrap();
52     let sock_path = dir.path().join("connect.sock");
53 
54     let listener = UnixListener::bind(&sock_path)?;
55 
56     let accept = listener.accept();
57     let connect = UnixStream::connect(&sock_path);
58     let ((mut server, _), mut client) = try_join(accept, connect).await?;
59 
60     // Shut down the client
61     AsyncWriteExt::shutdown(&mut client).await?;
62     // Read from the server should return 0 to indicate the channel has been closed.
63     let mut buf = [0u8; 1];
64     let n = server.read(&mut buf).await?;
65     assert_eq!(n, 0);
66     Ok(())
67 }
68 
/// Exercises the non-blocking `try_write` / `try_read` APIs and their
/// vectored variants together with readiness notifications:
///
/// 1. `readable()` on the server is pending until the client writes.
/// 2. The client fills its write side until `WouldBlock`, after which
///    `writable()` is pending; the server then drains and verifies the bytes.
/// 3. Step 2 is repeated with `try_write_vectored` / `try_read_vectored`.
/// 4. Dropping the client is eventually reported as `is_read_closed()`.
#[tokio::test]
async fn try_read_write() -> std::io::Result<()> {
    let msg = b"hello world";

    let dir = tempfile::tempdir()?;
    let bind_path = dir.path().join("bind.sock");

    // Create listener
    let listener = UnixListener::bind(&bind_path)?;

    // Create socket pair
    let client = UnixStream::connect(&bind_path).await?;

    let (server, _) = listener.accept().await?;
    // Every byte the client successfully writes, in order. The bytes drained
    // on the server side are compared against this.
    let mut written = msg.to_vec();

    // Track the server receiving data
    let mut readable = task::spawn(server.readable());
    assert_pending!(readable.poll());

    // Write data.
    client.writable().await?;
    assert_eq!(msg.len(), client.try_write(msg)?);

    // The task should be notified
    while !readable.is_woken() {
        tokio::task::yield_now().await;
    }

    // Fill the write buffer using non-vectored I/O.
    // Before each attempt, `writable()` must still resolve immediately;
    // the loop ends on the first `WouldBlock`.
    loop {
        // Still ready
        let mut writable = task::spawn(client.writable());
        assert_ready_ok!(writable.poll());

        match client.try_write(msg) {
            // Record only the prefix that was actually accepted.
            Ok(n) => written.extend(&msg[..n]),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                break;
            }
            Err(e) => panic!("error = {e:?}"),
        }
    }

    {
        // Write buffer full
        let mut writable = task::spawn(client.writable());
        assert_pending!(writable.poll());

        // Drain the socket from the server end using non-vectored I/O
        let mut read = vec![0; written.len()];
        let mut i = 0;

        while i < read.len() {
            server.readable().await?;

            match server.try_read(&mut read[i..]) {
                Ok(n) => i += n,
                // Readiness can be spurious; wait again.
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                Err(e) => panic!("error = {e:?}"),
            }
        }

        assert_eq!(read, written);
    }

    written.clear();
    client.writable().await.unwrap();

    // Fill the write buffer using vectored I/O
    // `msg` is split into 3-byte slices. A write of `n` bytes therefore
    // consumed the first `n` bytes of `msg`.
    let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect();
    loop {
        // Still ready
        let mut writable = task::spawn(client.writable());
        assert_ready_ok!(writable.poll());

        match client.try_write_vectored(&msg_bufs) {
            Ok(n) => written.extend(&msg[..n]),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                break;
            }
            Err(e) => panic!("error = {e:?}"),
        }
    }

    {
        // Write buffer full
        let mut writable = task::spawn(client.writable());
        assert_pending!(writable.poll());

        // Drain the socket from the server end using vectored I/O
        let mut read = vec![0; written.len()];
        let mut i = 0;

        while i < read.len() {
            server.readable().await?;

            // Rebuild the slices over the still-unfilled tail, 64 KiB each.
            let mut bufs: Vec<_> = read[i..]
                .chunks_mut(0x10000)
                .map(io::IoSliceMut::new)
                .collect();
            match server.try_read_vectored(&mut bufs) {
                Ok(n) => i += n,
                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
                Err(e) => panic!("error = {e:?}"),
            }
        }

        assert_eq!(read, written);
    }

    // Now, we listen for shutdown
    drop(client);

    // Readable readiness may be reported before the close itself is visible,
    // so keep waiting until `is_read_closed()` is set.
    loop {
        let ready = server.ready(Interest::READABLE).await?;

        if ready.is_read_closed() {
            break;
        } else {
            tokio::task::yield_now().await;
        }
    }

    Ok(())
}
195 
create_pair() -> (UnixStream, UnixStream)196 async fn create_pair() -> (UnixStream, UnixStream) {
197     let dir = assert_ok!(tempfile::tempdir());
198     let bind_path = dir.path().join("bind.sock");
199 
200     let listener = assert_ok!(UnixListener::bind(&bind_path));
201 
202     let accept = listener.accept();
203     let connect = UnixStream::connect(&bind_path);
204     let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
205 
206     (client, server)
207 }
208 
/// Awaits `poll_read_ready` on `$stream` and asserts it resolves to `Ok`.
macro_rules! assert_readable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_read_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
214 
/// Polls `poll_read_ready` on `$stream` exactly once and asserts it is
/// `Pending`. The wrapping future then resolves immediately, so the test
/// never waits for readiness.
macro_rules! assert_not_readable_by_polling {
    ($stream:expr) => {{
        let probe = poll_fn(|cx| {
            assert_pending!($stream.poll_read_ready(cx));
            Poll::Ready(())
        });
        probe.await;
    }};
}
224 
/// Awaits `poll_write_ready` on `$stream` and asserts it resolves to `Ok`.
macro_rules! assert_writable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_write_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
230 
/// Polls `poll_write_ready` on `$stream` exactly once and asserts it is
/// `Pending`. The wrapping future then resolves immediately.
macro_rules! assert_not_writable_by_polling {
    ($stream:expr) => {{
        let probe = poll_fn(|cx| {
            assert_pending!($stream.poll_write_ready(cx));
            Poll::Ready(())
        });
        probe.await;
    }};
}
240 
241 #[tokio::test]
poll_read_ready()242 async fn poll_read_ready() {
243     let (mut client, mut server) = create_pair().await;
244 
245     // Initial state - not readable.
246     assert_not_readable_by_polling!(server);
247 
248     // There is data in the buffer - readable.
249     assert_ok!(client.write_all(b"ping").await);
250     assert_readable_by_polling!(server);
251 
252     // Readable until calls to `poll_read` return `Poll::Pending`.
253     let mut buf = [0u8; 4];
254     assert_ok!(server.read_exact(&mut buf).await);
255     assert_readable_by_polling!(server);
256     read_until_pending(&mut server);
257     assert_not_readable_by_polling!(server);
258 
259     // Detect the client disconnect.
260     drop(client);
261     assert_readable_by_polling!(server);
262 }
263 
264 #[tokio::test]
poll_write_ready()265 async fn poll_write_ready() {
266     let (mut client, server) = create_pair().await;
267 
268     // Initial state - writable.
269     assert_writable_by_polling!(client);
270 
271     // No space to write - not writable.
272     write_until_pending(&mut client);
273     assert_not_writable_by_polling!(client);
274 
275     // Detect the server disconnect.
276     drop(server);
277     assert_writable_by_polling!(client);
278 }
279 
read_until_pending(stream: &mut UnixStream)280 fn read_until_pending(stream: &mut UnixStream) {
281     let mut buf = vec![0u8; 1024 * 1024];
282     loop {
283         match stream.try_read(&mut buf) {
284             Ok(_) => (),
285             Err(err) => {
286                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
287                 break;
288             }
289         }
290     }
291 }
292 
write_until_pending(stream: &mut UnixStream)293 fn write_until_pending(stream: &mut UnixStream) {
294     let buf = vec![0u8; 1024 * 1024];
295     loop {
296         match stream.try_write(&buf) {
297             Ok(_) => (),
298             Err(err) => {
299                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
300                 break;
301             }
302         }
303     }
304 }
305 
306 #[tokio::test]
try_read_buf() -> std::io::Result<()>307 async fn try_read_buf() -> std::io::Result<()> {
308     let msg = b"hello world";
309 
310     let dir = tempfile::tempdir()?;
311     let bind_path = dir.path().join("bind.sock");
312 
313     // Create listener
314     let listener = UnixListener::bind(&bind_path)?;
315 
316     // Create socket pair
317     let client = UnixStream::connect(&bind_path).await?;
318 
319     let (server, _) = listener.accept().await?;
320     let mut written = msg.to_vec();
321 
322     // Track the server receiving data
323     let mut readable = task::spawn(server.readable());
324     assert_pending!(readable.poll());
325 
326     // Write data.
327     client.writable().await?;
328     assert_eq!(msg.len(), client.try_write(msg)?);
329 
330     // The task should be notified
331     while !readable.is_woken() {
332         tokio::task::yield_now().await;
333     }
334 
335     // Fill the write buffer
336     loop {
337         // Still ready
338         let mut writable = task::spawn(client.writable());
339         assert_ready_ok!(writable.poll());
340 
341         match client.try_write(msg) {
342             Ok(n) => written.extend(&msg[..n]),
343             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
344                 break;
345             }
346             Err(e) => panic!("error = {e:?}"),
347         }
348     }
349 
350     {
351         // Write buffer full
352         let mut writable = task::spawn(client.writable());
353         assert_pending!(writable.poll());
354 
355         // Drain the socket from the server end
356         let mut read = Vec::with_capacity(written.len());
357         let mut i = 0;
358 
359         while i < read.capacity() {
360             server.readable().await?;
361 
362             match server.try_read_buf(&mut read) {
363                 Ok(n) => i += n,
364                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
365                 Err(e) => panic!("error = {e:?}"),
366             }
367         }
368 
369         assert_eq!(read, written);
370     }
371 
372     // Now, we listen for shutdown
373     drop(client);
374 
375     loop {
376         let ready = server.ready(Interest::READABLE).await?;
377 
378         if ready.is_read_closed() {
379             break;
380         } else {
381             tokio::task::yield_now().await;
382         }
383     }
384 
385     Ok(())
386 }
387 
388 // https://github.com/tokio-rs/tokio/issues/3879
389 #[tokio::test]
390 #[cfg(not(target_os = "macos"))]
epollhup() -> io::Result<()>391 async fn epollhup() -> io::Result<()> {
392     let dir = tempfile::Builder::new()
393         .prefix("tokio-uds-tests")
394         .tempdir()
395         .unwrap();
396     let sock_path = dir.path().join("connect.sock");
397 
398     let listener = UnixListener::bind(&sock_path)?;
399     let connect = UnixStream::connect(&sock_path);
400     tokio::pin!(connect);
401 
402     // Poll `connect` once.
403     poll_fn(|cx| {
404         use std::future::Future;
405 
406         assert_pending!(connect.as_mut().poll(cx));
407         Poll::Ready(())
408     })
409     .await;
410 
411     drop(listener);
412 
413     let err = connect.await.unwrap_err();
414     let errno = err.kind();
415     assert!(
416         // As far as I can tell, whether we see ECONNREFUSED or ECONNRESET here
417         // seems relatively inconsistent, at least on non-Linux operating
418         // systems. The difference in meaning between these errnos is not
419         // particularly well-defined, so let's just accept either.
420         matches!(
421             errno,
422             io::ErrorKind::ConnectionRefused | io::ErrorKind::ConnectionReset
423         ),
424         "unexpected error kind: {errno:?} (expected ConnectionRefused or ConnectionReset)"
425     );
426     Ok(())
427 }
428 
429 // test for https://github.com/tokio-rs/tokio/issues/6767
430 #[tokio::test]
431 #[cfg(any(target_os = "linux", target_os = "android"))]
abstract_socket_name()432 async fn abstract_socket_name() {
433     let socket_path = "\0aaa";
434     let listener = UnixListener::bind(socket_path).unwrap();
435 
436     let accept = listener.accept();
437     let connect = UnixStream::connect(&socket_path);
438 
439     let ((stream, _), _) = try_join(accept, connect).await.unwrap();
440 
441     let local_addr = stream.into_std().unwrap().local_addr().unwrap();
442     let abstract_path_name = local_addr.as_abstract_name().unwrap();
443 
444     // `as_abstract_name` removes leading zero bytes
445     assert_eq!(abstract_path_name, b"aaa");
446 }
447