#![warn(rust_2018_idioms)]
#![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
// Blocked on https://github.com/rust-lang/miri/issues/3911
#![cfg(not(miri))]

use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;

/// Simple test of running a !Send future via spawn_pinned
#[tokio::test]
async fn can_spawn_not_send_future() {
    let pool = task::LocalPoolHandle::new(1);

    let output = pool
        .spawn_pinned(|| {
            // Rc is !Send + !Sync
            let local_data = Rc::new("test");

            // This future holds an Rc, so it is !Send
            async move { local_data.to_string() }
        })
        .await
        .unwrap();

    assert_eq!(output, "test");
}

/// Dropping the join handle still lets the task execute
#[test]
fn can_drop_future_and_still_get_output() {
    let pool = task::LocalPoolHandle::new(1);
    let (sender, receiver) = std::sync::mpsc::channel();

    pool.spawn_pinned(move || {
        // Rc is !Send + !Sync
        let local_data = Rc::new("test");

        // This future holds an Rc, so it is !Send
        async move {
            let _ = sender.send(local_data.to_string());
        }
    });
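    // The returned join handle is dropped here, but the task still runs to
    // completion on the worker and delivers its output through the channel.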

    assert_eq!(receiver.recv(), Ok("test".to_string()));
}

#[test]
#[should_panic(expected = "assertion failed: pool_size > 0")]
fn cannot_create_zero_sized_pool() {
    let _pool = task::LocalPoolHandle::new(0);
}

/// We should be able to spawn multiple futures onto the pool at the same time.
#[tokio::test]
async fn can_spawn_multiple_futures() {
    let pool = task::LocalPoolHandle::new(2);

    let join_handle1 = pool.spawn_pinned(|| {
        let local_data = Rc::new("test1");
        async move { local_data.to_string() }
    });
    let join_handle2 = pool.spawn_pinned(|| {
        let local_data = Rc::new("test2");
        async move { local_data.to_string() }
    });

    assert_eq!(join_handle1.await.unwrap(), "test1");
    assert_eq!(join_handle2.await.unwrap(), "test2");
}

/// A panic in the spawned task causes the join handle to return an error,
/// but you can continue to spawn tasks.
#[tokio::test]
#[cfg(panic = "unwind")]
async fn task_panic_propagates() {
    let pool = task::LocalPoolHandle::new(1);

    let join_handle = pool.spawn_pinned(|| async {
        panic!("Test panic");
    });

    let result = join_handle.await;
    assert!(result.is_err());
    let error = result.unwrap_err();
    assert!(error.is_panic());
    let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
    assert_eq!(*panic_str, "Test panic");

    // Trying again with a "safe" task still works
    let join_handle = pool.spawn_pinned(|| async { "test" });
    let result = join_handle.await;
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test");
}

/// A panic during task creation causes the join handle to return an error,
/// but you can continue to spawn tasks.
#[tokio::test]
#[cfg(panic = "unwind")]
async fn callback_panic_does_not_kill_worker() {
    let pool = task::LocalPoolHandle::new(1);

    let join_handle = pool.spawn_pinned(|| {
        panic!("Test panic");
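        // The closure must still return a concrete future type, so an
        // unreachable empty async block follows the panic.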
        #[allow(unreachable_code)]
        async {}
    });

    let result = join_handle.await;
    assert!(result.is_err());
    let error = result.unwrap_err();
    assert!(error.is_panic());
    let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
    assert_eq!(*panic_str, "Test panic");

    // Trying again with a "safe" callback works
    let join_handle = pool.spawn_pinned(|| async { "test" });
    let result = join_handle.await;
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "test");
}

/// Canceling the task via the returned join handle cancels the spawned task
/// (which has a different, internal join handle).
#[tokio::test]
async fn task_cancellation_propagates() {
    let pool = task::LocalPoolHandle::new(1);
    let notify_dropped = Arc::new(());
    let weak_notify_dropped = Arc::downgrade(&notify_dropped);

    let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
    let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
    let join_handle = pool.spawn_pinned(|| async move {
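        // Hold the sender for the task's whole lifetime; it is dropped (and
        // the channel closed) only when the task itself is dropped.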
        let _drop_sender = drop_sender;
        // Move the Arc into the task
        let _notify_dropped = notify_dropped;
        let _ = start_sender.send(());

        // Keep the task running until it gets aborted
        futures::future::pending::<()>().await;
    });

    // Wait for the task to start
    let _ = start_receiver.await;

    join_handle.abort();

    // Wait for the inner task to abort, dropping the sender.
    // The top level join handle aborts quicker than the inner task (the abort
    // needs to propagate and get processed on the worker thread), so we can't
    // just await the top level join handle.
    let _ = drop_receiver.await;

    // Check that the Arc has been dropped. This verifies that the inner task
    // was canceled as well.
    assert!(weak_notify_dropped.upgrade().is_none());
}

/// Tasks should be given to the least burdened worker. When spawning two tasks
/// on a pool with two empty workers, the tasks should be spawned on separate
/// workers.
#[tokio::test]
async fn tasks_are_balanced() {
    let pool = task::LocalPoolHandle::new(2);

    // Spawn a task so one thread has a task count of 1
    let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
    let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
    let join_handle1 = pool.spawn_pinned(|| async move {
        let _ = start_sender1.send(());
        let _ = end_receiver1.await;
        std::thread::current().id()
    });

    // Wait for the first task to start up
    let _ = start_receiver1.await;

    // This task should be spawned on the other thread
    let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
    let join_handle2 = pool.spawn_pinned(|| async move {
        let _ = start_sender2.send(());
        std::thread::current().id()
    });

    // Wait for the second task to start up
    let _ = start_receiver2.await;

    // Allow the first task to end
    let _ = end_sender1.send(());

    let thread_id1 = join_handle1.await.unwrap();
    let thread_id2 = join_handle2.await.unwrap();

    // Since the first task was active when the second task spawned, they should
    // be on separate workers/threads.
    assert_ne!(thread_id1, thread_id2);
}

#[tokio::test]
async fn spawn_by_idx() {
    let pool = task::LocalPoolHandle::new(3);
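    // The barrier is sized for the three spawned tasks plus this test task.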
    let barrier = Arc::new(Barrier::new(4));
    let barrier1 = barrier.clone();
    let barrier2 = barrier.clone();
    let barrier3 = barrier.clone();

    let handle1 = pool.spawn_pinned_by_idx(
        || async move {
            barrier1.wait().await;
            std::thread::current().id()
        },
        0,
    );
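    // Spawn a second task onto worker 0 without keeping its join handle, so
    // worker 0 ends up with a load of two.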
    pool.spawn_pinned_by_idx(
        || async move {
            barrier2.wait().await;
            std::thread::current().id()
        },
        0,
    );
    let handle2 = pool.spawn_pinned_by_idx(
        || async move {
            barrier3.wait().await;
            std::thread::current().id()
        },
        1,
    );

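    // The tasks cannot complete until this test reaches the barrier below, so
    // the per-worker load counts read here are still deterministic.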
    let loads = pool.get_task_loads_for_each_worker();
    barrier.wait().await;
    assert_eq!(loads[0], 2);
    assert_eq!(loads[1], 1);
    assert_eq!(loads[2], 0);

    let thread_id1 = handle1.await.unwrap();
    let thread_id2 = handle2.await.unwrap();

    assert_ne!(thread_id1, thread_id2);
}