1 use std::pin::Pin;
2 use tokio::sync::watch::Receiver;
3 
4 use futures_core::Stream;
5 use tokio_util::sync::ReusableBoxFuture;
6 
7 use std::fmt;
8 use std::task::{ready, Context, Poll};
9 use tokio::sync::watch::error::RecvError;
10 
/// A wrapper around [`tokio::sync::watch::Receiver`] that implements [`Stream`].
///
/// When first polled, this stream yields the receiver's current value, whether
/// that is the channel's initial value or a value sent afterwards. Streams
/// created with [`WatchStream<T>::from_changes`] skip this step and instead start
/// by waiting for a change notification (see `Receiver::changed`).
///
/// After that, an item is yielded each time a change is observed. Each item is a
/// clone of the channel's latest value at the moment the change is observed, so
/// values that are overwritten before the stream is polled again are never
/// yielded. The stream ends (yields `None`) once waiting for a change fails,
/// which for a `watch` channel happens when the sender has been dropped.
///
/// # Examples
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio_stream::{StreamExt, wrappers::WatchStream};
/// use tokio::sync::watch;
///
/// let (tx, rx) = watch::channel("hello");
/// let mut rx = WatchStream::new(rx);
///
/// assert_eq!(rx.next().await, Some("hello"));
///
/// tx.send("goodbye").unwrap();
/// assert_eq!(rx.next().await, Some("goodbye"));
/// # }
/// ```
///
/// A value sent before the stream is first polled replaces the initial value, so
/// the initial value is never yielded:
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio_stream::{StreamExt, wrappers::WatchStream};
/// use tokio::sync::watch;
///
/// let (tx, rx) = watch::channel("hello");
/// let mut rx = WatchStream::new(rx);
///
/// // "hello" is never yielded: "goodbye" is sent below before the stream is
/// // polled for the first time, so it is the current value at that point.
///
/// tx.send("goodbye").unwrap();
/// assert_eq!(rx.next().await, Some("goodbye"));
/// # }
/// ```
///
/// Example with [`WatchStream<T>::from_changes`]:
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use futures::future::FutureExt;
/// use tokio::sync::watch;
/// use tokio_stream::{StreamExt, wrappers::WatchStream};
///
/// let (tx, rx) = watch::channel("hello");
/// let mut rx = WatchStream::from_changes(rx);
///
/// // The initial value is not yielded, so no item is ready yet:
/// assert!(rx.next().now_or_never().is_none());
///
/// tx.send("goodbye").unwrap();
/// assert_eq!(rx.next().await, Some("goodbye"));
/// # }
/// ```
///
/// [`tokio::sync::watch::Receiver`]: struct@tokio::sync::watch::Receiver
/// [`Stream`]: trait@crate::Stream
#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
pub struct WatchStream<T> {
    /// The pending wait for the next item.
    ///
    /// It resolves to the outcome of the wait together with the receiver itself.
    /// The receiver is handed back so that `poll_next` can move it into the next
    /// wait. `ReusableBoxFuture` lets that next future replace this one without a
    /// fresh allocation when the size and alignment permit.
    inner: ReusableBoxFuture<'static, (Result<(), RecvError>, Receiver<T>)>,
}
77 
make_future<T: Clone + Send + Sync>( mut rx: Receiver<T>, ) -> (Result<(), RecvError>, Receiver<T>)78 async fn make_future<T: Clone + Send + Sync>(
79     mut rx: Receiver<T>,
80 ) -> (Result<(), RecvError>, Receiver<T>) {
81     let result = rx.changed().await;
82     (result, rx)
83 }
84 
85 impl<T: 'static + Clone + Send + Sync> WatchStream<T> {
86     /// Create a new `WatchStream`.
new(rx: Receiver<T>) -> Self87     pub fn new(rx: Receiver<T>) -> Self {
88         Self {
89             inner: ReusableBoxFuture::new(async move { (Ok(()), rx) }),
90         }
91     }
92 
93     /// Create a new `WatchStream` that waits for the value to be changed.
from_changes(rx: Receiver<T>) -> Self94     pub fn from_changes(rx: Receiver<T>) -> Self {
95         Self {
96             inner: ReusableBoxFuture::new(make_future(rx)),
97         }
98     }
99 }
100 
101 impl<T: Clone + 'static + Send + Sync> Stream for WatchStream<T> {
102     type Item = T;
103 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>104     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105         let (result, mut rx) = ready!(self.inner.poll(cx));
106         match result {
107             Ok(_) => {
108                 let received = (*rx.borrow_and_update()).clone();
109                 self.inner.set(make_future(rx));
110                 Poll::Ready(Some(received))
111             }
112             Err(_) => {
113                 self.inner.set(make_future(rx));
114                 Poll::Ready(None)
115             }
116         }
117     }
118 }
119 
// `WatchStream` is `Unpin` for every `T`. Its only field is a `ReusableBoxFuture`,
// which keeps the pending future boxed, so moving the wrapper never moves that
// future. Being `Unpin` is what lets `poll_next` reach `self.inner` through
// `Pin<&mut Self>`.
impl<T> Unpin for WatchStream<T> {}
121 
122 impl<T> fmt::Debug for WatchStream<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result123     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124         f.debug_struct("WatchStream").finish()
125     }
126 }
127 
128 impl<T: 'static + Clone + Send + Sync> From<Receiver<T>> for WatchStream<T> {
from(recv: Receiver<T>) -> Self129     fn from(recv: Receiver<T>) -> Self {
130         Self::new(recv)
131     }
132 }
133