1 use alloc::sync::Arc;
2 use core::{
3     cell::UnsafeCell,
4     convert::identity,
5     fmt,
6     marker::PhantomData,
7     num::NonZeroUsize,
8     pin::Pin,
9     sync::atomic::{AtomicU8, Ordering},
10 };
11 
12 use pin_project_lite::pin_project;
13 
14 use futures_core::{
15     future::Future,
16     ready,
17     stream::{FusedStream, Stream},
18     task::{Context, Poll, Waker},
19 };
20 #[cfg(feature = "sink")]
21 use futures_sink::Sink;
22 use futures_task::{waker, ArcWake};
23 
24 use crate::stream::FuturesUnordered;
25 
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
/// method.
///
/// This is [`FlattenUnorderedWithFlowController`] using the `()` flow controller, which
/// treats every item of the base stream as an inner stream to be flattened.
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;
29 
/// Empty state: nothing needs polling and no poll or wake-up is in progress.
const NONE: u8 = 0b0000_0000;

/// The inner streams have to be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b0000_0001;

/// The base stream has to be polled.
const NEED_TO_POLL_STREAM: u8 = 0b0000_0010;

/// Union of both "need to poll" requests: base stream and inner streams.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_STREAM | NEED_TO_POLL_INNER_STREAMS;

/// `poll_next` is running right now.
const POLLING: u8 = 0b0000_0100;

/// A wrapped waker is in the middle of forwarding a wake-up.
const WAKING: u8 = 0b0000_1000;

/// The task has been woken and is expected to be polled again.
const WOKEN: u8 = 0b0001_0000;
50 
/// Poll-state word shared by the stream and its two wrapped wakers.
///
/// Cloning produces another handle to the same atomic cell, not an independent copy.
#[derive(Debug, Clone)]
struct SharedPollState {
    /// Bit set composed of the `NEED_TO_POLL_*`, `POLLING`, `WAKING` and `WOKEN` flags.
    state: Arc<AtomicU8>,
}
56 
57 impl SharedPollState {
58     /// Constructs new `SharedPollState` with the given state.
new(value: u8) -> Self59     fn new(value: u8) -> Self {
60         Self { state: Arc::new(AtomicU8::new(value)) }
61     }
62 
63     /// Attempts to start polling, returning stored state in case of success.
64     /// Returns `None` if either waker is waking at the moment.
start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)>65     fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
66         let value = self
67             .state
68             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
69                 if value & WAKING == NONE {
70                     Some(POLLING)
71                 } else {
72                     None
73                 }
74             })
75             .ok()?;
76         let bomb = PollStateBomb::new(self, Self::reset);
77 
78         Some((value, bomb))
79     }
80 
81     /// Attempts to start the waking process and performs bitwise or with the given value.
82     ///
83     /// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however
84     /// state will be disjuncted with the given value.
start_waking( &self, to_poll: u8, ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)>85     fn start_waking(
86         &self,
87         to_poll: u8,
88     ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
89         let value = self
90             .state
91             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
92                 let mut next_value = value | to_poll;
93                 if value & (WOKEN | POLLING) == NONE {
94                     next_value |= WAKING;
95                 }
96 
97                 if next_value != value {
98                     Some(next_value)
99                 } else {
100                     None
101                 }
102             })
103             .ok()?;
104 
105         // Only start the waking process if we're not in the polling/waking phase and the stream isn't woken already
106         if value & (WOKEN | POLLING | WAKING) == NONE {
107             let bomb = PollStateBomb::new(self, Self::stop_waking);
108 
109             Some((value, bomb))
110         } else {
111             None
112         }
113     }
114 
115     /// Sets current state to
116     /// - `!POLLING` allowing to use wakers
117     /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called,
118     ///   or `will_be_woken` flag supplied
119     /// - `!WAKING` as
120     ///   * Wakers called during the `POLLING` phase won't propagate their calls
121     ///   * `POLLING` phase can't start if some of the wakers are active
122     ///     So no wrapped waker can touch the inner waker's cell, it's safe to poll again.
stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8123     fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
124         self.state
125             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| {
126                 let mut next_value = to_poll;
127 
128                 value &= NEED_TO_POLL_ALL;
129                 if value != NONE || will_be_woken {
130                     next_value |= WOKEN;
131                 }
132                 next_value |= value;
133 
134                 Some(next_value & !POLLING & !WAKING)
135             })
136             .unwrap()
137     }
138 
139     /// Toggles state to non-waking, allowing to start polling.
stop_waking(&self) -> u8140     fn stop_waking(&self) -> u8 {
141         let value = self
142             .state
143             .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
144                 let next_value = value & !WAKING | WOKEN;
145 
146                 if next_value != value {
147                     Some(next_value)
148                 } else {
149                     None
150                 }
151             })
152             .unwrap_or_else(identity);
153 
154         debug_assert!(value & (WOKEN | POLLING | WAKING) == WAKING);
155         value
156     }
157 
158     /// Resets current state allowing to poll the stream and wake up wakers.
reset(&self) -> u8159     fn reset(&self) -> u8 {
160         self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst)
161     }
162 }
163 
164 /// Used to execute some function on the given state when dropped.
165 struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
166     state: &'a SharedPollState,
167     drop: Option<F>,
168 }
169 
170 impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
171     /// Constructs new bomb with the given state.
new(state: &'a SharedPollState, drop: F) -> Self172     fn new(state: &'a SharedPollState, drop: F) -> Self {
173         Self { state, drop: Some(drop) }
174     }
175 
176     /// Deactivates bomb, forces it to not call provided function when dropped.
deactivate(mut self)177     fn deactivate(mut self) {
178         self.drop.take();
179     }
180 }
181 
182 impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
drop(&mut self)183     fn drop(&mut self) {
184         if let Some(drop) = self.drop.take() {
185             (drop)(self.state);
186         }
187     }
188 }
189 
/// Waker adapter handed to the base stream and to the inner streams.
///
/// On `wake_by_ref` it sets `need_to_poll` in the shared poll state. If that call also starts
/// the waking phase, it forwards the wake-up to `inner_waker`.
struct WrappedWaker {
    // Waker of the task that most recently polled the stream. It is replaced on every
    // `poll_next` and accessed without a lock (see the SAFETY note on the impls below).
    inner_waker: UnsafeCell<Option<Waker>>,
    // State word shared with the stream and with the sibling wrapped waker.
    poll_state: SharedPollState,
    // The `NEED_TO_POLL_*` bit this waker sets when woken.
    need_to_poll: u8,
}

// SAFETY: the `UnsafeCell` around `inner_waker` is the only field that needs justification.
// - Writes happen only in `replace_waker`, which is called only after `start_polling` put
//   the shared state into `POLLING`.
// - Reads happen only in `wake_by_ref`, after `start_waking` succeeded in setting `WAKING`.
// - `start_polling` refuses to proceed while `WAKING` is set.
// - `start_waking` does not begin the waking phase while `POLLING`, `WAKING` or `WOKEN` is
//   set.
// Therefore a write and a read of the cell never overlap, and at most one waker reads it.
unsafe impl Send for WrappedWaker {}
unsafe impl Sync for WrappedWaker {}
200 
201 impl WrappedWaker {
202     /// Replaces given waker's inner_waker for polling stream/futures which will
203     /// update poll state on `wake_by_ref` call. Use only if you need several
204     /// contexts.
205     ///
206     /// ## Safety
207     ///
208     /// This function will modify waker's `inner_waker` via `UnsafeCell`, so
209     /// it should be used only during `POLLING` phase by one thread at the time.
replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>)210     unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
211         unsafe { *self_arc.inner_waker.get() = cx.waker().clone().into() }
212     }
213 
214     /// Attempts to start the waking process for the waker with the given value.
215     /// If succeeded, then the stream isn't yet woken and not being polled at the moment.
start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)>216     fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
217         self.poll_state.start_waking(self.need_to_poll)
218     }
219 }
220 
221 impl ArcWake for WrappedWaker {
wake_by_ref(self_arc: &Arc<Self>)222     fn wake_by_ref(self_arc: &Arc<Self>) {
223         if let Some((_, state_bomb)) = self_arc.start_waking() {
224             // Safety: now state is not `POLLING`
225             let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
226 
227             if let Some(inner_waker) = waker_opt.clone() {
228                 // Stop waking to allow polling stream
229                 drop(state_bomb);
230 
231                 // Wake up inner waker
232                 inner_waker.wake();
233             }
234         }
235     }
236 }
237 
pin_project! {
    /// Future which polls an optional inner stream once for its next item.
    ///
    /// If the stream is `Some`, the future calls `poll_next` on it. On `Poll::Ready(Some(item))`
    /// it resolves to `Some((item, next_item_fut))`, where `next_item_fut` is a new
    /// `PollStreamFut` that now owns the stream. On `Poll::Ready(None)` it resolves to `None`.
    /// If the stream is `None`, the future resolves to `None` immediately.
    ///
    /// `Poll::Pending` from the stream is forwarded unchanged.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    struct PollStreamFut<St> {
        // The stream to poll. It is moved out (leaving `None`) whenever the future resolves.
        #[pin]
        stream: Option<St>,
    }
}
253 
254 impl<St> PollStreamFut<St> {
255     /// Constructs new `PollStreamFut` using given `stream`.
new(stream: impl Into<Option<St>>) -> Self256     fn new(stream: impl Into<Option<St>>) -> Self {
257         Self { stream: stream.into() }
258     }
259 }
260 
261 impl<St: Stream + Unpin> Future for PollStreamFut<St> {
262     type Output = Option<(St::Item, Self)>;
263 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>264     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265         let mut stream = self.project().stream;
266 
267         let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
268             ready!(stream.poll_next(cx))
269         } else {
270             None
271         };
272         let next_item_fut = Self::new(stream.get_mut().take());
273         let out = item.map(|item| (item, next_item_fut));
274 
275         Poll::Ready(out)
276     }
277 }
278 
pin_project! {
    /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
    /// method, with the ability to specify a flow controller.
    #[project = FlattenUnorderedWithFlowControllerProj]
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // One pending "next item" future per inner stream currently being flattened.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream that produces the inner streams.
        #[pin]
        stream: St,
        // State word shared with both wrapped wakers.
        poll_state: SharedPollState,
        // While `inner_streams` holds at least this many futures, the base stream is not
        // polled. `None` means no limit.
        limit: Option<NonZeroUsize>,
        // Set once the base stream has returned `Poll::Ready(None)`.
        is_stream_done: bool,
        // Waker given to `inner_streams`; requests `NEED_TO_POLL_INNER_STREAMS` when woken.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker given to `stream`; requests `NEED_TO_POLL_STREAM` when woken.
        stream_waker: Arc<WrappedWaker>,
        // Type-level marker selecting the `FlowController` implementation.
        flow_controller: PhantomData<Fc>
    }
}
297 
298 impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc>
299 where
300     St: Stream + fmt::Debug,
301     St::Item: Stream + fmt::Debug,
302 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result303     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304         f.debug_struct("FlattenUnorderedWithFlowController")
305             .field("poll_state", &self.poll_state)
306             .field("inner_streams", &self.inner_streams)
307             .field("limit", &self.limit)
308             .field("stream", &self.stream)
309             .field("is_stream_done", &self.is_stream_done)
310             .field("flow_controller", &self.flow_controller)
311             .finish()
312     }
313 }
314 
315 impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc>
316 where
317     St: Stream,
318     Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
319     St::Item: Stream + Unpin,
320 {
new(stream: St, limit: Option<usize>) -> Self321     pub(crate) fn new(stream: St, limit: Option<usize>) -> Self {
322         let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
323 
324         Self {
325             inner_streams: FuturesUnordered::new(),
326             stream,
327             is_stream_done: false,
328             limit: limit.and_then(NonZeroUsize::new),
329             inner_streams_waker: Arc::new(WrappedWaker {
330                 inner_waker: UnsafeCell::new(None),
331                 poll_state: poll_state.clone(),
332                 need_to_poll: NEED_TO_POLL_INNER_STREAMS,
333             }),
334             stream_waker: Arc::new(WrappedWaker {
335                 inner_waker: UnsafeCell::new(None),
336                 poll_state: poll_state.clone(),
337                 need_to_poll: NEED_TO_POLL_STREAM,
338             }),
339             poll_state,
340             flow_controller: PhantomData,
341         }
342     }
343 
344     delegate_access_inner!(stream, St, ());
345 }
346 
/// Decides, for every item produced by the base stream, how flattening proceeds.
pub trait FlowController<I, O> {
    /// Maps a base-stream `item` to the [`FlowStep`] to take next.
    fn next_step(item: I) -> FlowStep<I, O>;
}

/// The default controller: every base-stream item is treated as an inner stream to flatten.
impl<I, O> FlowController<I, O> for () {
    fn next_step(item: I) -> FlowStep<I, O> {
        FlowStep::Continue(item)
    }
}

/// Describes the next flow step for an item received from the base stream.
///
/// Implements the common comparison traits so that steps can be compared and asserted on
/// (e.g. when testing a custom [`FlowController`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FlowStep<C, R> {
    /// Keep the item as an inner stream and continue the standard flattening flow.
    Continue(C),
    /// Immediately return the contained item from `poll_next` instead of flattening.
    Return(R),
}
367 
368 impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc>
369 where
370     St: Stream,
371 {
372     /// Checks if current `inner_streams` bucket size is greater than optional limit.
is_exceeded_limit(&self) -> bool373     fn is_exceeded_limit(&self) -> bool {
374         self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
375     }
376 }
377 
378 impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
379 where
380     St: FusedStream,
381     Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
382     St::Item: Stream + Unpin,
383 {
is_terminated(&self) -> bool384     fn is_terminated(&self) -> bool {
385         self.stream.is_terminated() && self.inner_streams.is_empty()
386     }
387 }
388 
impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream,
    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
    St::Item: Stream + Unpin,
{
    type Item = <St::Item as Stream>::Item;

    /// Yields the next item from any inner stream, pulling new inner streams from the base
    /// stream as needed.
    ///
    /// A single call performs up to two phases:
    /// 1. If `NEED_TO_POLL_STREAM` is set, the base stream is polled repeatedly. Every inner
    ///    stream it yields is pushed into `inner_streams`. This stops when the base stream is
    ///    pending or exhausted, when `limit` is reached, or when the flow controller returns
    ///    `FlowStep::Return`. In the last case that item is returned and phase 2 is skipped.
    /// 2. If `NEED_TO_POLL_INNER_STREAMS` is set, `inner_streams` is polled once.
    ///
    /// The `NEED_TO_POLL_*` bits come from the shared state. The wrapped wakers set them when
    /// the base stream or an inner stream wakes. Work left for a later call is recorded back
    /// into the shared state.
    ///
    /// Returns `Ready(None)` only once the base stream is exhausted and no inner stream remains.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Item to return from this call, if any.
        let mut next_item = None;
        // `NEED_TO_POLL_*` bits to request for the next call.
        let mut need_to_poll_next = NONE;

        let mut this = self.as_mut().project();

        // Attempt to start polling. If some waker is holding the lock (`WAKING`), wait in the
        // loop. This is a plain spin with no backoff or yield.
        let (mut poll_state_value, state_bomb) = loop {
            if let Some(value) = this.poll_state.start_polling() {
                break value;
            }
        };
        // `poll_state_value` is the state observed before entering `POLLING`; its
        // `NEED_TO_POLL_*` bits select the work below. If anything below panics, dropping
        // `state_bomb` resets the shared state to `NEED_TO_POLL_ALL`.

        // Safety: now state is `POLLING`, so no wrapped waker can enter the waking phase and
        // read `inner_waker` while it is being replaced.
        unsafe {
            WrappedWaker::replace_waker(this.stream_waker, cx);
            WrappedWaker::replace_waker(this.inner_streams_waker, cx)
        };

        if poll_state_value & NEED_TO_POLL_STREAM != NONE {
            // Created lazily: only needed if the base stream is actually polled.
            let mut stream_waker = None;

            // Here we need to poll the base stream.
            //
            // To improve performance, we will attempt to place as many items as we can
            // to the `FuturesUnordered` bucket before polling inner streams
            loop {
                if this.is_exceeded_limit() || *this.is_stream_done {
                    // We either exceeded the limit or the stream is exhausted
                    if !*this.is_stream_done {
                        // The stream needs to be polled in the next iteration
                        need_to_poll_next |= NEED_TO_POLL_STREAM;
                    }

                    break;
                } else {
                    let mut cx = Context::from_waker(
                        stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
                    );

                    match this.stream.as_mut().poll_next(&mut cx) {
                        Poll::Ready(Some(item)) => {
                            let next_item_fut = match Fc::next_step(item) {
                                // Propagates an item immediately (the main use-case is for errors)
                                FlowStep::Return(item) => {
                                    // Defer any pending inner-stream work to the next call and
                                    // skip polling the inner streams during this one.
                                    need_to_poll_next |= NEED_TO_POLL_STREAM
                                        | (poll_state_value & NEED_TO_POLL_INNER_STREAMS);
                                    poll_state_value &= !NEED_TO_POLL_INNER_STREAMS;

                                    next_item = Some(item);

                                    break;
                                }
                                // Yields an item and continues processing (normal case)
                                FlowStep::Continue(inner_stream) => {
                                    PollStreamFut::new(inner_stream)
                                }
                            };
                            // Add new stream to the inner streams bucket
                            this.inner_streams.as_mut().push(next_item_fut);
                            // Inner streams must be polled afterward
                            poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
                        }
                        Poll::Ready(None) => {
                            // Mark the base stream as done
                            *this.is_stream_done = true;
                        }
                        Poll::Pending => {
                            break;
                        }
                    }
                }
            }
        }

        if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
            let inner_streams_waker = waker(this.inner_streams_waker.clone());
            let mut cx = Context::from_waker(&inner_streams_waker);

            match this.inner_streams.as_mut().poll_next(&mut cx) {
                Poll::Ready(Some(Some((item, next_item_fut)))) => {
                    // Push next inner stream item future to the list of inner streams futures
                    this.inner_streams.as_mut().push(next_item_fut);
                    // Take the received item
                    next_item = Some(item);
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                Poll::Ready(Some(None)) => {
                    // An inner stream finished and was not re-pushed.
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                // `Pending`, or `Ready(None)` because the bucket is empty: nothing to record.
                _ => {}
            }
        }

        // We didn't have any `poll_next` panic, so it's time to deactivate the bomb
        state_bomb.deactivate();

        // Call the waker at the end of polling if
        let mut force_wake =
            // we need to poll the stream and didn't reach the limit yet
            need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
            // or we need to poll the inner streams again
            || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;

        // Stop polling and swap the latest state. `WOKEN` is set when a wake-up is due, so the
        // wrapped wakers won't forward duplicates until the next poll.
        poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
        // If state was changed during `POLLING` phase, we also need to manually call a waker:
        // wakers that fired while `POLLING` only recorded their bit and did not wake the task.
        force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;

        let is_done = *this.is_stream_done && this.inner_streams.is_empty();

        if next_item.is_some() || is_done {
            Poll::Ready(next_item)
        } else {
            // Returning `Pending` with outstanding work: wake ourselves so it gets done.
            if force_wake {
                cx.waker().wake_by_ref();
            }

            Poll::Pending
        }
    }
}
521 
// Forwarding impl of Sink from the underlying stream: when the base stream is also a
// `Sink<Item>`, the flattened stream exposes that sink too, using the base stream's error
// type. The forwarding to the `stream` field is presumably generated by `delegate_sink!`,
// which is defined elsewhere; verify there.
#[cfg(feature = "sink")]
impl<St, Item, Fc> Sink<Item> for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + Sink<Item>,
{
    type Error = St::Error;

    delegate_sink!(stream, Item);
}
532