1 use proc_macro2::TokenStream;
2 use quote::{format_ident, quote, ToTokens};
3 use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};
4 
5 /// The `stream_select!` macro.
stream_select(input: TokenStream) -> Result<TokenStream, syn::Error>6 pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
7     let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
8     if args.len() < 2 {
9         return Ok(quote! {
10            compile_error!("stream select macro needs at least two arguments.")
11         });
12     }
13     let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
14     let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
15     let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
16     let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
17     let args = args.iter().map(|e| e.to_token_stream());
18 
19     Ok(quote! {
20         {
21             #[derive(Debug)]
22             struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);
23 
24             enum StreamEnum<#(#generic_idents),*> {
25                 #(
26                     #generic_idents(#generic_idents)
27                 ),*,
28                 None,
29             }
30 
31             impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
32             where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
33             {
34                 type Item = ITEM;
35 
36                 fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
37                     match self.get_mut() {
38                         #(
39                             Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
40                         ),*,
41                         Self::None => panic!("StreamEnum::None should never be polled!"),
42                     }
43                 }
44             }
45 
46             impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
47             where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
48             {
49                 type Item = ITEM;
50 
51                 fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
52                     let Self(#(ref mut #field_idents),*) = self.get_mut();
53                     #(
54                         let mut #field_idents_2 = false;
55                     )*
56                     let mut any_pending = false;
57                     {
58                         let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
59                         __futures_crate::async_await::shuffle(&mut stream_array);
60 
61                         for mut s in stream_array {
62                             if let StreamEnum::None = s {
63                                 continue;
64                             } else {
65                                 match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
66                                     r @ __futures_crate::task::Poll::Ready(Some(_)) => {
67                                         return r;
68                                     },
69                                     __futures_crate::task::Poll::Pending => {
70                                         any_pending = true;
71                                     },
72                                     __futures_crate::task::Poll::Ready(None) => {
73                                         match s {
74                                             #(
75                                                 StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
76                                             ),*,
77                                             StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
78                                         }
79                                     },
80                                 }
81                             }
82                         }
83                     }
84                     #(
85                         if #field_idents_2 {
86                             *#field_idents = None;
87                         }
88                     )*
89                     if any_pending {
90                         __futures_crate::task::Poll::Pending
91                     } else {
92                         __futures_crate::task::Poll::Ready(None)
93                     }
94                 }
95 
96                 fn size_hint(&self) -> (usize, Option<usize>) {
97                     let mut s = (0, Some(0));
98                     #(
99                         if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
100                             s.0 += new_hint.0;
101                             // We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
102                             s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
103                         }
104                     )*
105                     s
106                 }
107             }
108 
109             StreamSelect(#(Some(#args)),*)
110 
111         }
112     })
113 }
114