1 use hashbrown::hash_map::RawEntryMut;
2 use hashbrown::HashMap;
3 use std::borrow::Borrow;
4 use std::collections::hash_map::RandomState;
5 use std::fmt;
6 use std::future::Future;
7 use std::hash::{BuildHasher, Hash, Hasher};
8 use std::marker::PhantomData;
9 use tokio::runtime::Handle;
10 use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11 
/// A collection of tasks spawned on a Tokio runtime, associated with hash map
/// keys.
///
/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
/// addition of a set of keys associated with each task. These keys allow
/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
/// `JoinMap` based on their keys, or [test whether a task corresponding to a
/// given key exists][contains] in the `JoinMap`.
///
/// In addition, when tasks in the `JoinMap` complete, they will return the
/// associated key along with the value returned by the task, if any.
///
/// A `JoinMap` can be used to await the completion of some or all of the tasks
/// in the map. The map is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `V`.
///
/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
///
/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the
/// documentation on unstable features][unstable] for details on how to enable
/// Tokio's unstable features.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     for i in 0..10 {
///         // Spawn a task on the `JoinMap` with `i` as its key.
///         map.spawn(i, async move { /* ... */ });
///     }
///
///     let mut seen = [false; 10];
///
///     // When a task completes, `join_next` returns the task's key along
///     // with its output.
///     while let Some((key, res)) = map.join_next().await {
///         seen[key] = true;
///         assert!(res.is_ok(), "task {} completed successfully!", key);
///     }
///
///     for i in 0..10 {
///         assert!(seen[i]);
///     }
/// }
/// ```
///
/// Cancel tasks based on their keys:
///
/// ```
/// use tokio_util::task::JoinMap;
///
/// #[tokio::main]
/// async fn main() {
///     let mut map = JoinMap::new();
///
///     map.spawn("hello world", async move { /* ... */ });
///     map.spawn("goodbye world", async move { /* ... */});
///
///     // Look up the "goodbye world" task in the map and abort it.
///     let aborted = map.abort("goodbye world");
///
///     // `JoinMap::abort` returns `true` if a task existed for the
///     // provided key.
///     assert!(aborted);
///
///     while let Some((key, res)) = map.join_next().await {
///         if key == "goodbye world" {
///             // The aborted task should complete with a cancelled `JoinError`.
///             assert!(res.unwrap_err().is_cancelled());
///         } else {
///             // Other tasks should complete normally.
///             assert!(res.is_ok());
///         }
///     }
/// }
/// ```
///
/// [`JoinSet`]: tokio::task::JoinSet
/// [unstable]: tokio#unstable-features
/// [abort]: fn@Self::abort
/// [abort_matching]: fn@Self::abort_matching
/// [contains]: fn@Self::contains_key
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
    /// indexed by their keys and task IDs.
    ///
    /// The [`Key`] type contains both the task's `K`-typed key provided when
    /// spawning tasks, and the task's ID. The IDs are stored here to resolve
    /// hash collisions when looking up tasks based on their pre-computed hash
    /// (as stored in the `hashes_by_task` map).
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,

    /// A map from task IDs to the hash of the key associated with that task.
    ///
    /// This map is used to perform reverse lookups of tasks in the
    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
    /// of that task's key, and then remove it from the `tasks_by_key` map using
    /// the raw hash code, resolving collisions by comparing task IDs.
    hashes_by_task: HashMap<Id, u64, S>,

    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
    /// `JoinMap`.
    ///
    /// Every spawned task lives here until it is joined, including tasks that
    /// were aborted and replaced because another task was spawned with the
    /// same key; such replaced tasks are no longer present in the two maps
    /// above.
    tasks: JoinSet<V>,
}
127 
/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
///
/// `Key` itself does not implement `Hash`: the map is only accessed through
/// hashbrown's raw-entry API with a hash computed from `key` alone, so the
/// `id` field never contributes to the hash.
#[derive(Debug)]
struct Key<K> {
    /// The user-provided key the task was spawned with.
    key: K,
    /// The ID of the task currently associated with `key`.
    id: Id,
}
142 
143 impl<K, V> JoinMap<K, V> {
144     /// Creates a new empty `JoinMap`.
145     ///
146     /// The `JoinMap` is initially created with a capacity of 0, so it will not
147     /// allocate until a task is first spawned on it.
148     ///
149     /// # Examples
150     ///
151     /// ```
152     /// use tokio_util::task::JoinMap;
153     /// let map: JoinMap<&str, i32> = JoinMap::new();
154     /// ```
155     #[inline]
156     #[must_use]
new() -> Self157     pub fn new() -> Self {
158         Self::with_hasher(RandomState::new())
159     }
160 
161     /// Creates an empty `JoinMap` with the specified capacity.
162     ///
163     /// The `JoinMap` will be able to hold at least `capacity` tasks without
164     /// reallocating.
165     ///
166     /// # Examples
167     ///
168     /// ```
169     /// use tokio_util::task::JoinMap;
170     /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
171     /// ```
172     #[inline]
173     #[must_use]
with_capacity(capacity: usize) -> Self174     pub fn with_capacity(capacity: usize) -> Self {
175         JoinMap::with_capacity_and_hasher(capacity, Default::default())
176     }
177 }
178 
179 impl<K, V, S: Clone> JoinMap<K, V, S> {
180     /// Creates an empty `JoinMap` which will use the given hash builder to hash
181     /// keys.
182     ///
183     /// The created map has the default initial capacity.
184     ///
185     /// Warning: `hash_builder` is normally randomly generated, and
186     /// is designed to allow `JoinMap` to be resistant to attacks that
187     /// cause many collisions and very poor performance. Setting it
188     /// manually using this function can expose a DoS attack vector.
189     ///
190     /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
191     /// the `JoinMap` to be useful, see its documentation for details.
192     #[inline]
193     #[must_use]
with_hasher(hash_builder: S) -> Self194     pub fn with_hasher(hash_builder: S) -> Self {
195         Self::with_capacity_and_hasher(0, hash_builder)
196     }
197 
198     /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
199     /// to hash the keys.
200     ///
201     /// The `JoinMap` will be able to hold at least `capacity` elements without
202     /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
203     ///
204     /// Warning: `hash_builder` is normally randomly generated, and
205     /// is designed to allow HashMaps to be resistant to attacks that
206     /// cause many collisions and very poor performance. Setting it
207     /// manually using this function can expose a DoS attack vector.
208     ///
209     /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
210     /// the `JoinMap`to be useful, see its documentation for details.
211     ///
212     /// # Examples
213     ///
214     /// ```
215     /// # #[tokio::main]
216     /// # async fn main() {
217     /// use tokio_util::task::JoinMap;
218     /// use std::collections::hash_map::RandomState;
219     ///
220     /// let s = RandomState::new();
221     /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
222     /// map.spawn(1, async move { "hello world!" });
223     /// # }
224     /// ```
225     #[inline]
226     #[must_use]
with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self227     pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
228         Self {
229             tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
230             hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
231             tasks: JoinSet::new(),
232         }
233     }
234 
235     /// Returns the number of tasks currently in the `JoinMap`.
len(&self) -> usize236     pub fn len(&self) -> usize {
237         let len = self.tasks_by_key.len();
238         debug_assert_eq!(len, self.hashes_by_task.len());
239         len
240     }
241 
242     /// Returns whether the `JoinMap` is empty.
is_empty(&self) -> bool243     pub fn is_empty(&self) -> bool {
244         let empty = self.tasks_by_key.is_empty();
245         debug_assert_eq!(empty, self.hashes_by_task.is_empty());
246         empty
247     }
248 
249     /// Returns the number of tasks the map can hold without reallocating.
250     ///
251     /// This number is a lower bound; the `JoinMap` might be able to hold
252     /// more, but is guaranteed to be able to hold at least this many.
253     ///
254     /// # Examples
255     ///
256     /// ```
257     /// use tokio_util::task::JoinMap;
258     ///
259     /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
260     /// assert!(map.capacity() >= 100);
261     /// ```
262     #[inline]
capacity(&self) -> usize263     pub fn capacity(&self) -> usize {
264         let capacity = self.tasks_by_key.capacity();
265         debug_assert_eq!(capacity, self.hashes_by_task.capacity());
266         capacity
267     }
268 }
269 
270 impl<K, V, S> JoinMap<K, V, S>
271 where
272     K: Hash + Eq,
273     V: 'static,
274     S: BuildHasher,
275 {
276     /// Spawn the provided task and store it in this `JoinMap` with the provided
277     /// key.
278     ///
279     /// If a task previously existed in the `JoinMap` for this key, that task
280     /// will be cancelled and replaced with the new one. The previous task will
281     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
282     /// *not* return a cancelled [`JoinError`] for that task.
283     ///
284     /// # Panics
285     ///
286     /// This method panics if called outside of a Tokio runtime.
287     ///
288     /// [`join_next`]: Self::join_next
289     #[track_caller]
spawn<F>(&mut self, key: K, task: F) where F: Future<Output = V>, F: Send + 'static, V: Send,290     pub fn spawn<F>(&mut self, key: K, task: F)
291     where
292         F: Future<Output = V>,
293         F: Send + 'static,
294         V: Send,
295     {
296         let task = self.tasks.spawn(task);
297         self.insert(key, task)
298     }
299 
300     /// Spawn the provided task on the provided runtime and store it in this
301     /// `JoinMap` with the provided key.
302     ///
303     /// If a task previously existed in the `JoinMap` for this key, that task
304     /// will be cancelled and replaced with the new one. The previous task will
305     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
306     /// *not* return a cancelled [`JoinError`] for that task.
307     ///
308     /// [`join_next`]: Self::join_next
309     #[track_caller]
spawn_on<F>(&mut self, key: K, task: F, handle: &Handle) where F: Future<Output = V>, F: Send + 'static, V: Send,310     pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
311     where
312         F: Future<Output = V>,
313         F: Send + 'static,
314         V: Send,
315     {
316         let task = self.tasks.spawn_on(task, handle);
317         self.insert(key, task);
318     }
319 
320     /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
321     /// key.
322     ///
323     /// If a task previously existed in the `JoinMap` for this key, that task
324     /// will be cancelled and replaced with the new one. The previous task will
325     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
326     /// *not* return a cancelled [`JoinError`] for that task.
327     ///
328     /// Note that blocking tasks cannot be cancelled after execution starts.
329     /// Replaced blocking tasks will still run to completion if the task has begun
330     /// to execute when it is replaced. A blocking task which is replaced before
331     /// it has been scheduled on a blocking worker thread will be cancelled.
332     ///
333     /// # Panics
334     ///
335     /// This method panics if called outside of a Tokio runtime.
336     ///
337     /// [`join_next`]: Self::join_next
338     #[track_caller]
spawn_blocking<F>(&mut self, key: K, f: F) where F: FnOnce() -> V, F: Send + 'static, V: Send,339     pub fn spawn_blocking<F>(&mut self, key: K, f: F)
340     where
341         F: FnOnce() -> V,
342         F: Send + 'static,
343         V: Send,
344     {
345         let task = self.tasks.spawn_blocking(f);
346         self.insert(key, task)
347     }
348 
349     /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
350     /// `JoinMap` with the provided key.
351     ///
352     /// If a task previously existed in the `JoinMap` for this key, that task
353     /// will be cancelled and replaced with the new one. The previous task will
354     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
355     /// *not* return a cancelled [`JoinError`] for that task.
356     ///
357     /// Note that blocking tasks cannot be cancelled after execution starts.
358     /// Replaced blocking tasks will still run to completion if the task has begun
359     /// to execute when it is replaced. A blocking task which is replaced before
360     /// it has been scheduled on a blocking worker thread will be cancelled.
361     ///
362     /// [`join_next`]: Self::join_next
363     #[track_caller]
spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) where F: FnOnce() -> V, F: Send + 'static, V: Send,364     pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
365     where
366         F: FnOnce() -> V,
367         F: Send + 'static,
368         V: Send,
369     {
370         let task = self.tasks.spawn_blocking_on(f, handle);
371         self.insert(key, task);
372     }
373 
374     /// Spawn the provided task on the current [`LocalSet`] and store it in this
375     /// `JoinMap` with the provided key.
376     ///
377     /// If a task previously existed in the `JoinMap` for this key, that task
378     /// will be cancelled and replaced with the new one. The previous task will
379     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
380     /// *not* return a cancelled [`JoinError`] for that task.
381     ///
382     /// # Panics
383     ///
384     /// This method panics if it is called outside of a `LocalSet`.
385     ///
386     /// [`LocalSet`]: tokio::task::LocalSet
387     /// [`join_next`]: Self::join_next
388     #[track_caller]
spawn_local<F>(&mut self, key: K, task: F) where F: Future<Output = V>, F: 'static,389     pub fn spawn_local<F>(&mut self, key: K, task: F)
390     where
391         F: Future<Output = V>,
392         F: 'static,
393     {
394         let task = self.tasks.spawn_local(task);
395         self.insert(key, task);
396     }
397 
398     /// Spawn the provided task on the provided [`LocalSet`] and store it in
399     /// this `JoinMap` with the provided key.
400     ///
401     /// If a task previously existed in the `JoinMap` for this key, that task
402     /// will be cancelled and replaced with the new one. The previous task will
403     /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
404     /// *not* return a cancelled [`JoinError`] for that task.
405     ///
406     /// [`LocalSet`]: tokio::task::LocalSet
407     /// [`join_next`]: Self::join_next
408     #[track_caller]
spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet) where F: Future<Output = V>, F: 'static,409     pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
410     where
411         F: Future<Output = V>,
412         F: 'static,
413     {
414         let task = self.tasks.spawn_local_on(task, local_set);
415         self.insert(key, task)
416     }
417 
insert(&mut self, key: K, abort: AbortHandle)418     fn insert(&mut self, key: K, abort: AbortHandle) {
419         let hash = self.hash(&key);
420         let id = abort.id();
421         let map_key = Key { id, key };
422 
423         // Insert the new key into the map of tasks by keys.
424         let entry = self
425             .tasks_by_key
426             .raw_entry_mut()
427             .from_hash(hash, |k| k.key == map_key.key);
428         match entry {
429             RawEntryMut::Occupied(mut occ) => {
430                 // There was a previous task spawned with the same key! Cancel
431                 // that task, and remove its ID from the map of hashes by task IDs.
432                 let Key { id: prev_id, .. } = occ.insert_key(map_key);
433                 occ.insert(abort).abort();
434                 let _prev_hash = self.hashes_by_task.remove(&prev_id);
435                 debug_assert_eq!(Some(hash), _prev_hash);
436             }
437             RawEntryMut::Vacant(vac) => {
438                 vac.insert(map_key, abort);
439             }
440         };
441 
442         // Associate the key's hash with this task's ID, for looking up tasks by ID.
443         let _prev = self.hashes_by_task.insert(id, hash);
444         debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
445     }
446 
447     /// Waits until one of the tasks in the map completes and returns its
448     /// output, along with the key corresponding to that task.
449     ///
450     /// Returns `None` if the map is empty.
451     ///
452     /// # Cancel Safety
453     ///
454     /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
455     /// statement and some other branch completes first, it is guaranteed that no tasks were
456     /// removed from this `JoinMap`.
457     ///
458     /// # Returns
459     ///
460     /// This function returns:
461     ///
462     ///  * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
463     ///    completed. The `value` is the return value of that ask, and `key` is
464     ///    the key associated with the task.
465     ///  * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
466     ///    panicked or been aborted. `key` is the key associated  with the task
467     ///    that panicked or was aborted.
468     ///  * `None` if the `JoinMap` is empty.
469     ///
470     /// [`tokio::select!`]: tokio::select
join_next(&mut self) -> Option<(K, Result<V, JoinError>)>471     pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
472         let (res, id) = match self.tasks.join_next_with_id().await {
473             Some(Ok((id, output))) => (Ok(output), id),
474             Some(Err(e)) => {
475                 let id = e.id();
476                 (Err(e), id)
477             }
478             None => return None,
479         };
480         let key = self.remove_by_id(id)?;
481         Some((key, res))
482     }
483 
484     /// Aborts all tasks and waits for them to finish shutting down.
485     ///
486     /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
487     /// a loop until it returns `None`.
488     ///
489     /// This method ignores any panics in the tasks shutting down. When this call returns, the
490     /// `JoinMap` will be empty.
491     ///
492     /// [`abort_all`]: fn@Self::abort_all
493     /// [`join_next`]: fn@Self::join_next
shutdown(&mut self)494     pub async fn shutdown(&mut self) {
495         self.abort_all();
496         while self.join_next().await.is_some() {}
497     }
498 
499     /// Abort the task corresponding to the provided `key`.
500     ///
501     /// If this `JoinMap` contains a task corresponding to `key`, this method
502     /// will abort that task and return `true`. Otherwise, if no task exists for
503     /// `key`, this method returns `false`.
504     ///
505     /// # Examples
506     ///
507     /// Aborting a task by key:
508     ///
509     /// ```
510     /// use tokio_util::task::JoinMap;
511     ///
512     /// # #[tokio::main]
513     /// # async fn main() {
514     /// let mut map = JoinMap::new();
515     ///
516     /// map.spawn("hello world", async move { /* ... */ });
517     /// map.spawn("goodbye world", async move { /* ... */});
518     ///
519     /// // Look up the "goodbye world" task in the map and abort it.
520     /// map.abort("goodbye world");
521     ///
522     /// while let Some((key, res)) = map.join_next().await {
523     ///     if key == "goodbye world" {
524     ///         // The aborted task should complete with a cancelled `JoinError`.
525     ///         assert!(res.unwrap_err().is_cancelled());
526     ///     } else {
527     ///         // Other tasks should complete normally.
528     ///         assert!(res.is_ok());
529     ///     }
530     /// }
531     /// # }
532     /// ```
533     ///
534     /// `abort` returns `true` if a task was aborted:
535     /// ```
536     /// use tokio_util::task::JoinMap;
537     ///
538     /// # #[tokio::main]
539     /// # async fn main() {
540     /// let mut map = JoinMap::new();
541     ///
542     /// map.spawn("hello world", async move { /* ... */ });
543     /// map.spawn("goodbye world", async move { /* ... */});
544     ///
545     /// // A task for the key "goodbye world" should exist in the map:
546     /// assert!(map.abort("goodbye world"));
547     ///
548     /// // Aborting a key that does not exist will return `false`:
549     /// assert!(!map.abort("goodbye universe"));
550     /// # }
551     /// ```
abort<Q: ?Sized>(&mut self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow<Q>,552     pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
553     where
554         Q: Hash + Eq,
555         K: Borrow<Q>,
556     {
557         match self.get_by_key(key) {
558             Some((_, handle)) => {
559                 handle.abort();
560                 true
561             }
562             None => false,
563         }
564     }
565 
566     /// Aborts all tasks with keys matching `predicate`.
567     ///
568     /// `predicate` is a function called with a reference to each key in the
569     /// map. If it returns `true` for a given key, the corresponding task will
570     /// be cancelled.
571     ///
572     /// # Examples
573     /// ```
574     /// use tokio_util::task::JoinMap;
575     ///
576     /// # // use the current thread rt so that spawned tasks don't
577     /// # // complete in the background before they can be aborted.
578     /// # #[tokio::main(flavor = "current_thread")]
579     /// # async fn main() {
580     /// let mut map = JoinMap::new();
581     ///
582     /// map.spawn("hello world", async move {
583     ///     // ...
584     ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
585     /// });
586     /// map.spawn("goodbye world", async move {
587     ///     // ...
588     ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
589     /// });
590     /// map.spawn("hello san francisco", async move {
591     ///     // ...
592     ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
593     /// });
594     /// map.spawn("goodbye universe", async move {
595     ///     // ...
596     ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
597     /// });
598     ///
599     /// // Abort all tasks whose keys begin with "goodbye"
600     /// map.abort_matching(|key| key.starts_with("goodbye"));
601     ///
602     /// let mut seen = 0;
603     /// while let Some((key, res)) = map.join_next().await {
604     ///     seen += 1;
605     ///     if key.starts_with("goodbye") {
606     ///         // The aborted task should complete with a cancelled `JoinError`.
607     ///         assert!(res.unwrap_err().is_cancelled());
608     ///     } else {
609     ///         // Other tasks should complete normally.
610     ///         assert!(key.starts_with("hello"));
611     ///         assert!(res.is_ok());
612     ///     }
613     /// }
614     ///
615     /// // All spawned tasks should have completed.
616     /// assert_eq!(seen, 4);
617     /// # }
618     /// ```
abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool)619     pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
620         // Note: this method iterates over the tasks and keys *without* removing
621         // any entries, so that the keys from aborted tasks can still be
622         // returned when calling `join_next` in the future.
623         for (Key { ref key, .. }, task) in &self.tasks_by_key {
624             if predicate(key) {
625                 task.abort();
626             }
627         }
628     }
629 
630     /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
631     ///
632     /// If a task has completed, but its output hasn't yet been consumed by a
633     /// call to [`join_next`], this method will still return its key.
634     ///
635     /// [`join_next`]: fn@Self::join_next
keys(&self) -> JoinMapKeys<'_, K, V>636     pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
637         JoinMapKeys {
638             iter: self.tasks_by_key.keys(),
639             _value: PhantomData,
640         }
641     }
642 
643     /// Returns `true` if this `JoinMap` contains a task for the provided key.
644     ///
645     /// If the task has completed, but its output hasn't yet been consumed by a
646     /// call to [`join_next`], this method will still return `true`.
647     ///
648     /// [`join_next`]: fn@Self::join_next
contains_key<Q: ?Sized>(&self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow<Q>,649     pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
650     where
651         Q: Hash + Eq,
652         K: Borrow<Q>,
653     {
654         self.get_by_key(key).is_some()
655     }
656 
657     /// Returns `true` if this `JoinMap` contains a task with the provided
658     /// [task ID].
659     ///
660     /// If the task has completed, but its output hasn't yet been consumed by a
661     /// call to [`join_next`], this method will still return `true`.
662     ///
663     /// [`join_next`]: fn@Self::join_next
664     /// [task ID]: tokio::task::Id
contains_task(&self, task: &Id) -> bool665     pub fn contains_task(&self, task: &Id) -> bool {
666         self.get_by_id(task).is_some()
667     }
668 
669     /// Reserves capacity for at least `additional` more tasks to be spawned
670     /// on this `JoinMap` without reallocating for the map of task keys. The
671     /// collection may reserve more space to avoid frequent reallocations.
672     ///
673     /// Note that spawning a task will still cause an allocation for the task
674     /// itself.
675     ///
676     /// # Panics
677     ///
678     /// Panics if the new allocation size overflows [`usize`].
679     ///
680     /// # Examples
681     ///
682     /// ```
683     /// use tokio_util::task::JoinMap;
684     ///
685     /// let mut map: JoinMap<&str, i32> = JoinMap::new();
686     /// map.reserve(10);
687     /// ```
688     #[inline]
reserve(&mut self, additional: usize)689     pub fn reserve(&mut self, additional: usize) {
690         self.tasks_by_key.reserve(additional);
691         self.hashes_by_task.reserve(additional);
692     }
693 
694     /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
695     /// down as much as possible while maintaining the internal rules
696     /// and possibly leaving some space in accordance with the resize policy.
697     ///
698     /// # Examples
699     ///
700     /// ```
701     /// # #[tokio::main]
702     /// # async fn main() {
703     /// use tokio_util::task::JoinMap;
704     ///
705     /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
706     /// map.spawn(1, async move { 2 });
707     /// map.spawn(3, async move { 4 });
708     /// assert!(map.capacity() >= 100);
709     /// map.shrink_to_fit();
710     /// assert!(map.capacity() >= 2);
711     /// # }
712     /// ```
713     #[inline]
shrink_to_fit(&mut self)714     pub fn shrink_to_fit(&mut self) {
715         self.hashes_by_task.shrink_to_fit();
716         self.tasks_by_key.shrink_to_fit();
717     }
718 
719     /// Shrinks the capacity of the map with a lower limit. It will drop
720     /// down no lower than the supplied limit while maintaining the internal rules
721     /// and possibly leaving some space in accordance with the resize policy.
722     ///
723     /// If the current capacity is less than the lower limit, this is a no-op.
724     ///
725     /// # Examples
726     ///
727     /// ```
728     /// # #[tokio::main]
729     /// # async fn main() {
730     /// use tokio_util::task::JoinMap;
731     ///
732     /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
733     /// map.spawn(1, async move { 2 });
734     /// map.spawn(3, async move { 4 });
735     /// assert!(map.capacity() >= 100);
736     /// map.shrink_to(10);
737     /// assert!(map.capacity() >= 10);
738     /// map.shrink_to(0);
739     /// assert!(map.capacity() >= 2);
740     /// # }
741     /// ```
742     #[inline]
shrink_to(&mut self, min_capacity: usize)743     pub fn shrink_to(&mut self, min_capacity: usize) {
744         self.hashes_by_task.shrink_to(min_capacity);
745         self.tasks_by_key.shrink_to(min_capacity)
746     }
747 
748     /// Look up a task in the map by its key, returning the key and abort handle.
get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)> where Q: Hash + Eq, K: Borrow<Q>,749     fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
750     where
751         Q: Hash + Eq,
752         K: Borrow<Q>,
753     {
754         let hash = self.hash(key);
755         self.tasks_by_key
756             .raw_entry()
757             .from_hash(hash, |k| k.key.borrow() == key)
758     }
759 
760     /// Look up a task in the map by its task ID, returning the key and abort handle.
get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)>761     fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
762         let hash = self.hashes_by_task.get(id)?;
763         self.tasks_by_key
764             .raw_entry()
765             .from_hash(*hash, |k| &k.id == id)
766     }
767 
768     /// Remove a task from the map by ID, returning the key for that task.
remove_by_id(&mut self, id: Id) -> Option<K>769     fn remove_by_id(&mut self, id: Id) -> Option<K> {
770         // Get the hash for the given ID.
771         let hash = self.hashes_by_task.remove(&id)?;
772 
773         // Remove the entry for that hash.
774         let entry = self
775             .tasks_by_key
776             .raw_entry_mut()
777             .from_hash(hash, |k| k.id == id);
778         let (Key { id: _key_id, key }, handle) = match entry {
779             RawEntryMut::Occupied(entry) => entry.remove_entry(),
780             _ => return None,
781         };
782         debug_assert_eq!(_key_id, id);
783         debug_assert_eq!(id, handle.id());
784         self.hashes_by_task.remove(&id);
785         Some(key)
786     }
787 
788     /// Returns the hash for a given key.
789     #[inline]
hash<Q: ?Sized>(&self, key: &Q) -> u64 where Q: Hash,790     fn hash<Q: ?Sized>(&self, key: &Q) -> u64
791     where
792         Q: Hash,
793     {
794         let mut hasher = self.tasks_by_key.hasher().build_hasher();
795         key.hash(&mut hasher);
796         hasher.finish()
797     }
798 }
799 
800 impl<K, V, S> JoinMap<K, V, S>
801 where
802     V: 'static,
803 {
804     /// Aborts all tasks on this `JoinMap`.
805     ///
806     /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
807     /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
abort_all(&mut self)808     pub fn abort_all(&mut self) {
809         self.tasks.abort_all()
810     }
811 
812     /// Removes all tasks from this `JoinMap` without aborting them.
813     ///
814     /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
815     /// is dropped. They may still be aborted by key.
detach_all(&mut self)816     pub fn detach_all(&mut self) {
817         self.tasks.detach_all();
818         self.tasks_by_key.clear();
819         self.hashes_by_task.clear();
820     }
821 }
822 
823 // Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
824 // Debug`, since no value is ever actually stored in the map.
825 impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result826     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
827         // format the task keys and abort handles a little nicer by just
828         // printing the key and task ID pairs, without format the `Key` struct
829         // itself or the `AbortHandle`, which would just format the task's ID
830         // again.
831         struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
832         impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
833             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
834                 f.debug_map()
835                     .entries(self.0.keys().map(|Key { key, id }| (key, id)))
836                     .finish()
837             }
838         }
839 
840         f.debug_struct("JoinMap")
841             // The `tasks_by_key` map is the only one that contains information
842             // that's really worth formatting for the user, since it contains
843             // the tasks' keys and IDs. The other fields are basically
844             // implementation details.
845             .field("tasks", &KeySet(&self.tasks_by_key))
846             .finish()
847     }
848 }
849 
850 impl<K, V> Default for JoinMap<K, V> {
default() -> Self851     fn default() -> Self {
852         Self::new()
853     }
854 }
855 
856 // === impl Key ===
857 
858 impl<K: Hash> Hash for Key<K> {
859     // Don't include the task ID in the hash.
860     #[inline]
hash<H: Hasher>(&self, hasher: &mut H)861     fn hash<H: Hasher>(&self, hasher: &mut H) {
862         self.key.hash(hasher);
863     }
864 }
865 
// Because we override `Hash` for this type to ignore the task ID, we must also
// override the `PartialEq` impl to ignore it, so that any two instances that
// compare equal also hash equally (as the `Hash` contract requires).
868 impl<K: PartialEq> PartialEq for Key<K> {
869     #[inline]
eq(&self, other: &Self) -> bool870     fn eq(&self, other: &Self) -> bool {
871         self.key == other.key
872     }
873 }
874 
875 impl<K: Eq> Eq for Key<K> {}
876 
877 /// An iterator over the keys of a [`JoinMap`].
878 #[derive(Debug, Clone)]
879 pub struct JoinMapKeys<'a, K, V> {
880     iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
881     /// To make it easier to change `JoinMap` in the future, keep V as a generic
882     /// parameter.
883     _value: PhantomData<&'a V>,
884 }
885 
886 impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
887     type Item = &'a K;
888 
next(&mut self) -> Option<&'a K>889     fn next(&mut self) -> Option<&'a K> {
890         self.iter.next().map(|key| &key.key)
891     }
892 
size_hint(&self) -> (usize, Option<usize>)893     fn size_hint(&self) -> (usize, Option<usize>) {
894         self.iter.size_hint()
895     }
896 }
897 
898 impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
len(&self) -> usize899     fn len(&self) -> usize {
900         self.iter.len()
901     }
902 }
903 
904 impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}
905