1 use hashbrown::hash_map::RawEntryMut; 2 use hashbrown::HashMap; 3 use std::borrow::Borrow; 4 use std::collections::hash_map::RandomState; 5 use std::fmt; 6 use std::future::Future; 7 use std::hash::{BuildHasher, Hash, Hasher}; 8 use std::marker::PhantomData; 9 use tokio::runtime::Handle; 10 use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; 11 12 /// A collection of tasks spawned on a Tokio runtime, associated with hash map 13 /// keys. 14 /// 15 /// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the 16 /// addition of a set of keys associated with each task. These keys allow 17 /// [cancelling a task][abort] or [multiple tasks][abort_matching] in the 18 /// `JoinMap` based on their keys, or [test whether a task corresponding to a 19 /// given key exists][contains] in the `JoinMap`. 20 /// 21 /// In addition, when tasks in the `JoinMap` complete, they will return the 22 /// associated key along with the value returned by the task, if any. 23 /// 24 /// A `JoinMap` can be used to await the completion of some or all of the tasks 25 /// in the map. The map is not ordered, and the tasks will be returned in the 26 /// order they complete. 27 /// 28 /// All of the tasks must have the same return type `V`. 29 /// 30 /// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted. 31 /// 32 /// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the 33 /// documentation on unstable features][unstable] for details on how to enable 34 /// Tokio's unstable features. 35 /// 36 /// # Examples 37 /// 38 /// Spawn multiple tasks and wait for them: 39 /// 40 /// ``` 41 /// use tokio_util::task::JoinMap; 42 /// 43 /// #[tokio::main] 44 /// async fn main() { 45 /// let mut map = JoinMap::new(); 46 /// 47 /// for i in 0..10 { 48 /// // Spawn a task on the `JoinMap` with `i` as its key. 49 /// map.spawn(i, async move { /* ... */ }); 50 /// } 51 /// 52 /// let mut seen = [false; 10]; 53 /// 54 /// // When a task completes, `join_next` returns the task's key along 55 /// // with its output. 56 /// while let Some((key, res)) = map.join_next().await { 57 /// seen[key] = true; 58 /// assert!(res.is_ok(), "task {} completed successfully!", key); 59 /// } 60 /// 61 /// for i in 0..10 { 62 /// assert!(seen[i]); 63 /// } 64 /// } 65 /// ``` 66 /// 67 /// Cancel tasks based on their keys: 68 /// 69 /// ``` 70 /// use tokio_util::task::JoinMap; 71 /// 72 /// #[tokio::main] 73 /// async fn main() { 74 /// let mut map = JoinMap::new(); 75 /// 76 /// map.spawn("hello world", async move { /* ... */ }); 77 /// map.spawn("goodbye world", async move { /* ... */}); 78 /// 79 /// // Look up the "goodbye world" task in the map and abort it. 80 /// let aborted = map.abort("goodbye world"); 81 /// 82 /// // `JoinMap::abort` returns `true` if a task existed for the 83 /// // provided key. 84 /// assert!(aborted); 85 /// 86 /// while let Some((key, res)) = map.join_next().await { 87 /// if key == "goodbye world" { 88 /// // The aborted task should complete with a cancelled `JoinError`. 89 /// assert!(res.unwrap_err().is_cancelled()); 90 /// } else { 91 /// // Other tasks should complete normally. 92 /// assert!(res.is_ok()); 93 /// } 94 /// } 95 /// } 96 /// ``` 97 /// 98 /// [`JoinSet`]: tokio::task::JoinSet 99 /// [unstable]: tokio#unstable-features 100 /// [abort]: fn@Self::abort 101 /// [abort_matching]: fn@Self::abort_matching 102 /// [contains]: fn@Self::contains_key 103 #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] 104 pub struct JoinMap<K, V, S = RandomState> { 105 /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, 106 /// indexed by their keys and task IDs. 107 /// 108 /// The [`Key`] type contains both the task's `K`-typed key provided when 109 /// spawning tasks, and the task's IDs. The IDs are stored here to resolve 110 /// hash collisions when looking up tasks based on their pre-computed hash 111 /// (as stored in the `hashes_by_task` map). 112 tasks_by_key: HashMap<Key<K>, AbortHandle, S>, 113 114 /// A map from task IDs to the hash of the key associated with that task. 115 /// 116 /// This map is used to perform reverse lookups of tasks in the 117 /// `tasks_by_key` map based on their task IDs. When a task terminates, the 118 /// ID is provided to us by the `JoinSet`, so we can look up the hash value 119 /// of that task's key, and then remove it from the `tasks_by_key` map using 120 /// the raw hash code, resolving collisions by comparing task IDs. 121 hashes_by_task: HashMap<Id, u64, S>, 122 123 /// The [`JoinSet`] that awaits the completion of tasks spawned on this 124 /// `JoinMap`. 125 tasks: JoinSet<V>, 126 } 127 128 /// A [`JoinMap`] key. 129 /// 130 /// This holds both a `K`-typed key (the actual key as seen by the user), _and_ 131 /// a task ID, so that hash collisions between `K`-typed keys can be resolved 132 /// using either `K`'s `Eq` impl *or* by checking the task IDs. 133 /// 134 /// This allows looking up a task using either an actual key (such as when the 135 /// user queries the map with a key), *or* using a task ID and a hash (such as 136 /// when removing completed tasks from the map). 137 #[derive(Debug)] 138 struct Key<K> { 139 key: K, 140 id: Id, 141 } 142 143 impl<K, V> JoinMap<K, V> { 144 /// Creates a new empty `JoinMap`. 145 /// 146 /// The `JoinMap` is initially created with a capacity of 0, so it will not 147 /// allocate until a task is first spawned on it. 148 /// 149 /// # Examples 150 /// 151 /// ``` 152 /// use tokio_util::task::JoinMap; 153 /// let map: JoinMap<&str, i32> = JoinMap::new(); 154 /// ``` 155 #[inline] 156 #[must_use] new() -> Self157 pub fn new() -> Self { 158 Self::with_hasher(RandomState::new()) 159 } 160 161 /// Creates an empty `JoinMap` with the specified capacity. 162 /// 163 /// The `JoinMap` will be able to hold at least `capacity` tasks without 164 /// reallocating. 165 /// 166 /// # Examples 167 /// 168 /// ``` 169 /// use tokio_util::task::JoinMap; 170 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10); 171 /// ``` 172 #[inline] 173 #[must_use] with_capacity(capacity: usize) -> Self174 pub fn with_capacity(capacity: usize) -> Self { 175 JoinMap::with_capacity_and_hasher(capacity, Default::default()) 176 } 177 } 178 179 impl<K, V, S: Clone> JoinMap<K, V, S> { 180 /// Creates an empty `JoinMap` which will use the given hash builder to hash 181 /// keys. 182 /// 183 /// The created map has the default initial capacity. 184 /// 185 /// Warning: `hash_builder` is normally randomly generated, and 186 /// is designed to allow `JoinMap` to be resistant to attacks that 187 /// cause many collisions and very poor performance. Setting it 188 /// manually using this function can expose a DoS attack vector. 189 /// 190 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for 191 /// the `JoinMap` to be useful, see its documentation for details. 192 #[inline] 193 #[must_use] with_hasher(hash_builder: S) -> Self194 pub fn with_hasher(hash_builder: S) -> Self { 195 Self::with_capacity_and_hasher(0, hash_builder) 196 } 197 198 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder` 199 /// to hash the keys. 200 /// 201 /// The `JoinMap` will be able to hold at least `capacity` elements without 202 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate. 203 /// 204 /// Warning: `hash_builder` is normally randomly generated, and 205 /// is designed to allow HashMaps to be resistant to attacks that 206 /// cause many collisions and very poor performance. Setting it 207 /// manually using this function can expose a DoS attack vector. 208 /// 209 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for 210 /// the `JoinMap`to be useful, see its documentation for details. 211 /// 212 /// # Examples 213 /// 214 /// ``` 215 /// # #[tokio::main] 216 /// # async fn main() { 217 /// use tokio_util::task::JoinMap; 218 /// use std::collections::hash_map::RandomState; 219 /// 220 /// let s = RandomState::new(); 221 /// let mut map = JoinMap::with_capacity_and_hasher(10, s); 222 /// map.spawn(1, async move { "hello world!" }); 223 /// # } 224 /// ``` 225 #[inline] 226 #[must_use] with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self227 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { 228 Self { 229 tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), 230 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), 231 tasks: JoinSet::new(), 232 } 233 } 234 235 /// Returns the number of tasks currently in the `JoinMap`. len(&self) -> usize236 pub fn len(&self) -> usize { 237 let len = self.tasks_by_key.len(); 238 debug_assert_eq!(len, self.hashes_by_task.len()); 239 len 240 } 241 242 /// Returns whether the `JoinMap` is empty. is_empty(&self) -> bool243 pub fn is_empty(&self) -> bool { 244 let empty = self.tasks_by_key.is_empty(); 245 debug_assert_eq!(empty, self.hashes_by_task.is_empty()); 246 empty 247 } 248 249 /// Returns the number of tasks the map can hold without reallocating. 250 /// 251 /// This number is a lower bound; the `JoinMap` might be able to hold 252 /// more, but is guaranteed to be able to hold at least this many. 253 /// 254 /// # Examples 255 /// 256 /// ``` 257 /// use tokio_util::task::JoinMap; 258 /// 259 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100); 260 /// assert!(map.capacity() >= 100); 261 /// ``` 262 #[inline] capacity(&self) -> usize263 pub fn capacity(&self) -> usize { 264 let capacity = self.tasks_by_key.capacity(); 265 debug_assert_eq!(capacity, self.hashes_by_task.capacity()); 266 capacity 267 } 268 } 269 270 impl<K, V, S> JoinMap<K, V, S> 271 where 272 K: Hash + Eq, 273 V: 'static, 274 S: BuildHasher, 275 { 276 /// Spawn the provided task and store it in this `JoinMap` with the provided 277 /// key. 278 /// 279 /// If a task previously existed in the `JoinMap` for this key, that task 280 /// will be cancelled and replaced with the new one. The previous task will 281 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 282 /// *not* return a cancelled [`JoinError`] for that task. 283 /// 284 /// # Panics 285 /// 286 /// This method panics if called outside of a Tokio runtime. 287 /// 288 /// [`join_next`]: Self::join_next 289 #[track_caller] spawn<F>(&mut self, key: K, task: F) where F: Future<Output = V>, F: Send + 'static, V: Send,290 pub fn spawn<F>(&mut self, key: K, task: F) 291 where 292 F: Future<Output = V>, 293 F: Send + 'static, 294 V: Send, 295 { 296 let task = self.tasks.spawn(task); 297 self.insert(key, task) 298 } 299 300 /// Spawn the provided task on the provided runtime and store it in this 301 /// `JoinMap` with the provided key. 302 /// 303 /// If a task previously existed in the `JoinMap` for this key, that task 304 /// will be cancelled and replaced with the new one. The previous task will 305 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 306 /// *not* return a cancelled [`JoinError`] for that task. 307 /// 308 /// [`join_next`]: Self::join_next 309 #[track_caller] spawn_on<F>(&mut self, key: K, task: F, handle: &Handle) where F: Future<Output = V>, F: Send + 'static, V: Send,310 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle) 311 where 312 F: Future<Output = V>, 313 F: Send + 'static, 314 V: Send, 315 { 316 let task = self.tasks.spawn_on(task, handle); 317 self.insert(key, task); 318 } 319 320 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided 321 /// key. 322 /// 323 /// If a task previously existed in the `JoinMap` for this key, that task 324 /// will be cancelled and replaced with the new one. The previous task will 325 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 326 /// *not* return a cancelled [`JoinError`] for that task. 327 /// 328 /// Note that blocking tasks cannot be cancelled after execution starts. 329 /// Replaced blocking tasks will still run to completion if the task has begun 330 /// to execute when it is replaced. A blocking task which is replaced before 331 /// it has been scheduled on a blocking worker thread will be cancelled. 332 /// 333 /// # Panics 334 /// 335 /// This method panics if called outside of a Tokio runtime. 336 /// 337 /// [`join_next`]: Self::join_next 338 #[track_caller] spawn_blocking<F>(&mut self, key: K, f: F) where F: FnOnce() -> V, F: Send + 'static, V: Send,339 pub fn spawn_blocking<F>(&mut self, key: K, f: F) 340 where 341 F: FnOnce() -> V, 342 F: Send + 'static, 343 V: Send, 344 { 345 let task = self.tasks.spawn_blocking(f); 346 self.insert(key, task) 347 } 348 349 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this 350 /// `JoinMap` with the provided key. 351 /// 352 /// If a task previously existed in the `JoinMap` for this key, that task 353 /// will be cancelled and replaced with the new one. The previous task will 354 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 355 /// *not* return a cancelled [`JoinError`] for that task. 356 /// 357 /// Note that blocking tasks cannot be cancelled after execution starts. 358 /// Replaced blocking tasks will still run to completion if the task has begun 359 /// to execute when it is replaced. A blocking task which is replaced before 360 /// it has been scheduled on a blocking worker thread will be cancelled. 361 /// 362 /// [`join_next`]: Self::join_next 363 #[track_caller] spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) where F: FnOnce() -> V, F: Send + 'static, V: Send,364 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle) 365 where 366 F: FnOnce() -> V, 367 F: Send + 'static, 368 V: Send, 369 { 370 let task = self.tasks.spawn_blocking_on(f, handle); 371 self.insert(key, task); 372 } 373 374 /// Spawn the provided task on the current [`LocalSet`] and store it in this 375 /// `JoinMap` with the provided key. 376 /// 377 /// If a task previously existed in the `JoinMap` for this key, that task 378 /// will be cancelled and replaced with the new one. The previous task will 379 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 380 /// *not* return a cancelled [`JoinError`] for that task. 381 /// 382 /// # Panics 383 /// 384 /// This method panics if it is called outside of a `LocalSet`. 385 /// 386 /// [`LocalSet`]: tokio::task::LocalSet 387 /// [`join_next`]: Self::join_next 388 #[track_caller] spawn_local<F>(&mut self, key: K, task: F) where F: Future<Output = V>, F: 'static,389 pub fn spawn_local<F>(&mut self, key: K, task: F) 390 where 391 F: Future<Output = V>, 392 F: 'static, 393 { 394 let task = self.tasks.spawn_local(task); 395 self.insert(key, task); 396 } 397 398 /// Spawn the provided task on the provided [`LocalSet`] and store it in 399 /// this `JoinMap` with the provided key. 400 /// 401 /// If a task previously existed in the `JoinMap` for this key, that task 402 /// will be cancelled and replaced with the new one. The previous task will 403 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will 404 /// *not* return a cancelled [`JoinError`] for that task. 405 /// 406 /// [`LocalSet`]: tokio::task::LocalSet 407 /// [`join_next`]: Self::join_next 408 #[track_caller] spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet) where F: Future<Output = V>, F: 'static,409 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet) 410 where 411 F: Future<Output = V>, 412 F: 'static, 413 { 414 let task = self.tasks.spawn_local_on(task, local_set); 415 self.insert(key, task) 416 } 417 insert(&mut self, key: K, abort: AbortHandle)418 fn insert(&mut self, key: K, abort: AbortHandle) { 419 let hash = self.hash(&key); 420 let id = abort.id(); 421 let map_key = Key { id, key }; 422 423 // Insert the new key into the map of tasks by keys. 424 let entry = self 425 .tasks_by_key 426 .raw_entry_mut() 427 .from_hash(hash, |k| k.key == map_key.key); 428 match entry { 429 RawEntryMut::Occupied(mut occ) => { 430 // There was a previous task spawned with the same key! Cancel 431 // that task, and remove its ID from the map of hashes by task IDs. 432 let Key { id: prev_id, .. } = occ.insert_key(map_key); 433 occ.insert(abort).abort(); 434 let _prev_hash = self.hashes_by_task.remove(&prev_id); 435 debug_assert_eq!(Some(hash), _prev_hash); 436 } 437 RawEntryMut::Vacant(vac) => { 438 vac.insert(map_key, abort); 439 } 440 }; 441 442 // Associate the key's hash with this task's ID, for looking up tasks by ID. 443 let _prev = self.hashes_by_task.insert(id, hash); 444 debug_assert!(_prev.is_none(), "no prior task should have had the same ID"); 445 } 446 447 /// Waits until one of the tasks in the map completes and returns its 448 /// output, along with the key corresponding to that task. 449 /// 450 /// Returns `None` if the map is empty. 451 /// 452 /// # Cancel Safety 453 /// 454 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`] 455 /// statement and some other branch completes first, it is guaranteed that no tasks were 456 /// removed from this `JoinMap`. 457 /// 458 /// # Returns 459 /// 460 /// This function returns: 461 /// 462 /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has 463 /// completed. The `value` is the return value of that ask, and `key` is 464 /// the key associated with the task. 465 /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has 466 /// panicked or been aborted. `key` is the key associated with the task 467 /// that panicked or was aborted. 468 /// * `None` if the `JoinMap` is empty. 469 /// 470 /// [`tokio::select!`]: tokio::select join_next(&mut self) -> Option<(K, Result<V, JoinError>)>471 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> { 472 let (res, id) = match self.tasks.join_next_with_id().await { 473 Some(Ok((id, output))) => (Ok(output), id), 474 Some(Err(e)) => { 475 let id = e.id(); 476 (Err(e), id) 477 } 478 None => return None, 479 }; 480 let key = self.remove_by_id(id)?; 481 Some((key, res)) 482 } 483 484 /// Aborts all tasks and waits for them to finish shutting down. 485 /// 486 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in 487 /// a loop until it returns `None`. 488 /// 489 /// This method ignores any panics in the tasks shutting down. When this call returns, the 490 /// `JoinMap` will be empty. 491 /// 492 /// [`abort_all`]: fn@Self::abort_all 493 /// [`join_next`]: fn@Self::join_next shutdown(&mut self)494 pub async fn shutdown(&mut self) { 495 self.abort_all(); 496 while self.join_next().await.is_some() {} 497 } 498 499 /// Abort the task corresponding to the provided `key`. 500 /// 501 /// If this `JoinMap` contains a task corresponding to `key`, this method 502 /// will abort that task and return `true`. Otherwise, if no task exists for 503 /// `key`, this method returns `false`. 504 /// 505 /// # Examples 506 /// 507 /// Aborting a task by key: 508 /// 509 /// ``` 510 /// use tokio_util::task::JoinMap; 511 /// 512 /// # #[tokio::main] 513 /// # async fn main() { 514 /// let mut map = JoinMap::new(); 515 /// 516 /// map.spawn("hello world", async move { /* ... */ }); 517 /// map.spawn("goodbye world", async move { /* ... */}); 518 /// 519 /// // Look up the "goodbye world" task in the map and abort it. 520 /// map.abort("goodbye world"); 521 /// 522 /// while let Some((key, res)) = map.join_next().await { 523 /// if key == "goodbye world" { 524 /// // The aborted task should complete with a cancelled `JoinError`. 525 /// assert!(res.unwrap_err().is_cancelled()); 526 /// } else { 527 /// // Other tasks should complete normally. 528 /// assert!(res.is_ok()); 529 /// } 530 /// } 531 /// # } 532 /// ``` 533 /// 534 /// `abort` returns `true` if a task was aborted: 535 /// ``` 536 /// use tokio_util::task::JoinMap; 537 /// 538 /// # #[tokio::main] 539 /// # async fn main() { 540 /// let mut map = JoinMap::new(); 541 /// 542 /// map.spawn("hello world", async move { /* ... */ }); 543 /// map.spawn("goodbye world", async move { /* ... */}); 544 /// 545 /// // A task for the key "goodbye world" should exist in the map: 546 /// assert!(map.abort("goodbye world")); 547 /// 548 /// // Aborting a key that does not exist will return `false`: 549 /// assert!(!map.abort("goodbye universe")); 550 /// # } 551 /// ``` abort<Q: ?Sized>(&mut self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow<Q>,552 pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool 553 where 554 Q: Hash + Eq, 555 K: Borrow<Q>, 556 { 557 match self.get_by_key(key) { 558 Some((_, handle)) => { 559 handle.abort(); 560 true 561 } 562 None => false, 563 } 564 } 565 566 /// Aborts all tasks with keys matching `predicate`. 567 /// 568 /// `predicate` is a function called with a reference to each key in the 569 /// map. If it returns `true` for a given key, the corresponding task will 570 /// be cancelled. 571 /// 572 /// # Examples 573 /// ``` 574 /// use tokio_util::task::JoinMap; 575 /// 576 /// # // use the current thread rt so that spawned tasks don't 577 /// # // complete in the background before they can be aborted. 578 /// # #[tokio::main(flavor = "current_thread")] 579 /// # async fn main() { 580 /// let mut map = JoinMap::new(); 581 /// 582 /// map.spawn("hello world", async move { 583 /// // ... 584 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! 585 /// }); 586 /// map.spawn("goodbye world", async move { 587 /// // ... 588 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! 589 /// }); 590 /// map.spawn("hello san francisco", async move { 591 /// // ... 592 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! 593 /// }); 594 /// map.spawn("goodbye universe", async move { 595 /// // ... 596 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! 597 /// }); 598 /// 599 /// // Abort all tasks whose keys begin with "goodbye" 600 /// map.abort_matching(|key| key.starts_with("goodbye")); 601 /// 602 /// let mut seen = 0; 603 /// while let Some((key, res)) = map.join_next().await { 604 /// seen += 1; 605 /// if key.starts_with("goodbye") { 606 /// // The aborted task should complete with a cancelled `JoinError`. 607 /// assert!(res.unwrap_err().is_cancelled()); 608 /// } else { 609 /// // Other tasks should complete normally. 610 /// assert!(key.starts_with("hello")); 611 /// assert!(res.is_ok()); 612 /// } 613 /// } 614 /// 615 /// // All spawned tasks should have completed. 616 /// assert_eq!(seen, 4); 617 /// # } 618 /// ``` abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool)619 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) { 620 // Note: this method iterates over the tasks and keys *without* removing 621 // any entries, so that the keys from aborted tasks can still be 622 // returned when calling `join_next` in the future. 623 for (Key { ref key, .. }, task) in &self.tasks_by_key { 624 if predicate(key) { 625 task.abort(); 626 } 627 } 628 } 629 630 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order. 631 /// 632 /// If a task has completed, but its output hasn't yet been consumed by a 633 /// call to [`join_next`], this method will still return its key. 634 /// 635 /// [`join_next`]: fn@Self::join_next keys(&self) -> JoinMapKeys<'_, K, V>636 pub fn keys(&self) -> JoinMapKeys<'_, K, V> { 637 JoinMapKeys { 638 iter: self.tasks_by_key.keys(), 639 _value: PhantomData, 640 } 641 } 642 643 /// Returns `true` if this `JoinMap` contains a task for the provided key. 644 /// 645 /// If the task has completed, but its output hasn't yet been consumed by a 646 /// call to [`join_next`], this method will still return `true`. 647 /// 648 /// [`join_next`]: fn@Self::join_next contains_key<Q: ?Sized>(&self, key: &Q) -> bool where Q: Hash + Eq, K: Borrow<Q>,649 pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool 650 where 651 Q: Hash + Eq, 652 K: Borrow<Q>, 653 { 654 self.get_by_key(key).is_some() 655 } 656 657 /// Returns `true` if this `JoinMap` contains a task with the provided 658 /// [task ID]. 659 /// 660 /// If the task has completed, but its output hasn't yet been consumed by a 661 /// call to [`join_next`], this method will still return `true`. 662 /// 663 /// [`join_next`]: fn@Self::join_next 664 /// [task ID]: tokio::task::Id contains_task(&self, task: &Id) -> bool665 pub fn contains_task(&self, task: &Id) -> bool { 666 self.get_by_id(task).is_some() 667 } 668 669 /// Reserves capacity for at least `additional` more tasks to be spawned 670 /// on this `JoinMap` without reallocating for the map of task keys. The 671 /// collection may reserve more space to avoid frequent reallocations. 672 /// 673 /// Note that spawning a task will still cause an allocation for the task 674 /// itself. 675 /// 676 /// # Panics 677 /// 678 /// Panics if the new allocation size overflows [`usize`]. 679 /// 680 /// # Examples 681 /// 682 /// ``` 683 /// use tokio_util::task::JoinMap; 684 /// 685 /// let mut map: JoinMap<&str, i32> = JoinMap::new(); 686 /// map.reserve(10); 687 /// ``` 688 #[inline] reserve(&mut self, additional: usize)689 pub fn reserve(&mut self, additional: usize) { 690 self.tasks_by_key.reserve(additional); 691 self.hashes_by_task.reserve(additional); 692 } 693 694 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop 695 /// down as much as possible while maintaining the internal rules 696 /// and possibly leaving some space in accordance with the resize policy. 697 /// 698 /// # Examples 699 /// 700 /// ``` 701 /// # #[tokio::main] 702 /// # async fn main() { 703 /// use tokio_util::task::JoinMap; 704 /// 705 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100); 706 /// map.spawn(1, async move { 2 }); 707 /// map.spawn(3, async move { 4 }); 708 /// assert!(map.capacity() >= 100); 709 /// map.shrink_to_fit(); 710 /// assert!(map.capacity() >= 2); 711 /// # } 712 /// ``` 713 #[inline] shrink_to_fit(&mut self)714 pub fn shrink_to_fit(&mut self) { 715 self.hashes_by_task.shrink_to_fit(); 716 self.tasks_by_key.shrink_to_fit(); 717 } 718 719 /// Shrinks the capacity of the map with a lower limit. It will drop 720 /// down no lower than the supplied limit while maintaining the internal rules 721 /// and possibly leaving some space in accordance with the resize policy. 722 /// 723 /// If the current capacity is less than the lower limit, this is a no-op. 724 /// 725 /// # Examples 726 /// 727 /// ``` 728 /// # #[tokio::main] 729 /// # async fn main() { 730 /// use tokio_util::task::JoinMap; 731 /// 732 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100); 733 /// map.spawn(1, async move { 2 }); 734 /// map.spawn(3, async move { 4 }); 735 /// assert!(map.capacity() >= 100); 736 /// map.shrink_to(10); 737 /// assert!(map.capacity() >= 10); 738 /// map.shrink_to(0); 739 /// assert!(map.capacity() >= 2); 740 /// # } 741 /// ``` 742 #[inline] shrink_to(&mut self, min_capacity: usize)743 pub fn shrink_to(&mut self, min_capacity: usize) { 744 self.hashes_by_task.shrink_to(min_capacity); 745 self.tasks_by_key.shrink_to(min_capacity) 746 } 747 748 /// Look up a task in the map by its key, returning the key and abort handle. get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)> where Q: Hash + Eq, K: Borrow<Q>,749 fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)> 750 where 751 Q: Hash + Eq, 752 K: Borrow<Q>, 753 { 754 let hash = self.hash(key); 755 self.tasks_by_key 756 .raw_entry() 757 .from_hash(hash, |k| k.key.borrow() == key) 758 } 759 760 /// Look up a task in the map by its task ID, returning the key and abort handle. get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)>761 fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> { 762 let hash = self.hashes_by_task.get(id)?; 763 self.tasks_by_key 764 .raw_entry() 765 .from_hash(*hash, |k| &k.id == id) 766 } 767 768 /// Remove a task from the map by ID, returning the key for that task. remove_by_id(&mut self, id: Id) -> Option<K>769 fn remove_by_id(&mut self, id: Id) -> Option<K> { 770 // Get the hash for the given ID. 771 let hash = self.hashes_by_task.remove(&id)?; 772 773 // Remove the entry for that hash. 774 let entry = self 775 .tasks_by_key 776 .raw_entry_mut() 777 .from_hash(hash, |k| k.id == id); 778 let (Key { id: _key_id, key }, handle) = match entry { 779 RawEntryMut::Occupied(entry) => entry.remove_entry(), 780 _ => return None, 781 }; 782 debug_assert_eq!(_key_id, id); 783 debug_assert_eq!(id, handle.id()); 784 self.hashes_by_task.remove(&id); 785 Some(key) 786 } 787 788 /// Returns the hash for a given key. 789 #[inline] hash<Q: ?Sized>(&self, key: &Q) -> u64 where Q: Hash,790 fn hash<Q: ?Sized>(&self, key: &Q) -> u64 791 where 792 Q: Hash, 793 { 794 let mut hasher = self.tasks_by_key.hasher().build_hasher(); 795 key.hash(&mut hasher); 796 hasher.finish() 797 } 798 } 799 800 impl<K, V, S> JoinMap<K, V, S> 801 where 802 V: 'static, 803 { 804 /// Aborts all tasks on this `JoinMap`. 805 /// 806 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete 807 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty. abort_all(&mut self)808 pub fn abort_all(&mut self) { 809 self.tasks.abort_all() 810 } 811 812 /// Removes all tasks from this `JoinMap` without aborting them. 813 /// 814 /// The tasks removed by this call will continue to run in the background even if the `JoinMap` 815 /// is dropped. They may still be aborted by key. detach_all(&mut self)816 pub fn detach_all(&mut self) { 817 self.tasks.detach_all(); 818 self.tasks_by_key.clear(); 819 self.hashes_by_task.clear(); 820 } 821 } 822 823 // Hand-written `fmt::Debug` implementation in order to avoid requiring `V: 824 // Debug`, since no value is ever actually stored in the map. 825 impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result826 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 827 // format the task keys and abort handles a little nicer by just 828 // printing the key and task ID pairs, without format the `Key` struct 829 // itself or the `AbortHandle`, which would just format the task's ID 830 // again. 831 struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>); 832 impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> { 833 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 834 f.debug_map() 835 .entries(self.0.keys().map(|Key { key, id }| (key, id))) 836 .finish() 837 } 838 } 839 840 f.debug_struct("JoinMap") 841 // The `tasks_by_key` map is the only one that contains information 842 // that's really worth formatting for the user, since it contains 843 // the tasks' keys and IDs. The other fields are basically 844 // implementation details. 845 .field("tasks", &KeySet(&self.tasks_by_key)) 846 .finish() 847 } 848 } 849 850 impl<K, V> Default for JoinMap<K, V> { default() -> Self851 fn default() -> Self { 852 Self::new() 853 } 854 } 855 856 // === impl Key === 857 858 impl<K: Hash> Hash for Key<K> { 859 // Don't include the task ID in the hash. 860 #[inline] hash<H: Hasher>(&self, hasher: &mut H)861 fn hash<H: Hasher>(&self, hasher: &mut H) { 862 self.key.hash(hasher); 863 } 864 } 865 866 // Because we override `Hash` for this type, we must also override the 867 // `PartialEq` impl, so that all instances with the same hash are equal. 868 impl<K: PartialEq> PartialEq for Key<K> { 869 #[inline] eq(&self, other: &Self) -> bool870 fn eq(&self, other: &Self) -> bool { 871 self.key == other.key 872 } 873 } 874 875 impl<K: Eq> Eq for Key<K> {} 876 877 /// An iterator over the keys of a [`JoinMap`]. 878 #[derive(Debug, Clone)] 879 pub struct JoinMapKeys<'a, K, V> { 880 iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>, 881 /// To make it easier to change `JoinMap` in the future, keep V as a generic 882 /// parameter. 883 _value: PhantomData<&'a V>, 884 } 885 886 impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { 887 type Item = &'a K; 888 next(&mut self) -> Option<&'a K>889 fn next(&mut self) -> Option<&'a K> { 890 self.iter.next().map(|key| &key.key) 891 } 892 size_hint(&self) -> (usize, Option<usize>)893 fn size_hint(&self) -> (usize, Option<usize>) { 894 self.iter.size_hint() 895 } 896 } 897 898 impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> { len(&self) -> usize899 fn len(&self) -> usize { 900 self.iter.len() 901 } 902 } 903 904 impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {} 905