1 use super::{Connected, Server};
2 use crate::transport::service::ServerIo;
3 use hyper::server::{
4     accept::Accept,
5     conn::{AddrIncoming, AddrStream},
6 };
7 use std::{
8     net::SocketAddr,
9     pin::Pin,
10     task::{Context, Poll},
11     time::Duration,
12 };
13 use tokio::{
14     io::{AsyncRead, AsyncWrite},
15     net::TcpListener,
16 };
17 use tokio_stream::{Stream, StreamExt};
18 
19 #[cfg(not(feature = "tls"))]
tcp_incoming<IO, IE, L>( incoming: impl Stream<Item = Result<IO, IE>>, _server: Server<L>, ) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into<crate::Error>,20 pub(crate) fn tcp_incoming<IO, IE, L>(
21     incoming: impl Stream<Item = Result<IO, IE>>,
22     _server: Server<L>,
23 ) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
24 where
25     IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
26     IE: Into<crate::Error>,
27 {
28     async_stream::try_stream! {
29         tokio::pin!(incoming);
30 
31         while let Some(item) = incoming.next().await {
32             yield item.map(ServerIo::new_io)?
33         }
34     }
35 }
36 
/// Turns a stream of accepted connections into a stream of server I/O objects, performing the
/// TLS handshake for each connection when `server.tls` is configured.
///
/// Handshakes run as tasks on a [`tokio::task::JoinSet`], so `incoming` keeps being polled while
/// handshakes are in progress and one slow client does not hold up the next accept.
///
/// Error handling:
/// * an error from `incoming` itself (the listener) is yielded and ends the stream, matching the
///   non-TLS variant of this function;
/// * a failed or panicked handshake only affects that one client, so it is logged at debug level
///   and the loop keeps accepting.
///
/// When `incoming` ends, the stream ends; handshakes still in flight are aborted when the
/// `JoinSet` is dropped.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE, L>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        tokio::pin!(incoming);

        // In-flight TLS handshakes.
        let mut tasks = tokio::task::JoinSet::new();

        loop {
            match select(&mut incoming, &mut tasks).await {
                SelectOutput::Incoming(stream) => {
                    if let Some(tls) = &server.tls {
                        let tls = tls.clone();
                        tasks.spawn(async move {
                            let io = tls.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(io))
                        });
                    } else {
                        yield ServerIo::new_io(stream);
                    }
                }

                SelectOutput::Io(io) => {
                    yield io;
                }

                // The listener itself failed: surface the error and end the stream instead of
                // silently looping (which could spin if the listener keeps failing).
                SelectOutput::TcpErr(e) => Err(e)?,

                // A single handshake failed: log it and keep serving other clients.
                SelectOutput::TlsErr(e) => {
                    tracing::debug!(message = "Accept loop error.", error = %e);
                }

                SelectOutput::Done => {
                    break;
                }
            }
        }
    }
}

/// Waits for whichever happens first: the next item from `incoming`, or the next finished
/// handshake task in `tasks`.
///
/// When `tasks` is empty only `incoming` is awaited, because `JoinSet::join_next` on an empty set
/// resolves to `None` immediately. That early return is also what guarantees the `expect` on
/// `join_next` below cannot fire.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::Error>,
{
    if tasks.is_empty() {
        return match incoming.try_next().await {
            Ok(Some(stream)) => SelectOutput::Incoming(stream),
            Ok(None) => SelectOutput::Done,
            Err(e) => SelectOutput::TcpErr(e.into()),
        };
    }

    tokio::select! {
        stream = incoming.try_next() => {
            match stream {
                Ok(Some(stream)) => SelectOutput::Incoming(stream),
                Ok(None) => SelectOutput::Done,
                Err(e) => SelectOutput::TcpErr(e.into()),
            }
        }

        accept = tasks.join_next() => {
            match accept.expect("JoinSet should never end") {
                Ok(Ok(io)) => SelectOutput::Io(io),
                // The handshake returned an error.
                Ok(Err(e)) => SelectOutput::TlsErr(e),
                // The handshake task panicked or was cancelled.
                Err(e) => SelectOutput::TlsErr(e.into()),
            }
        }
    }
}

/// The outcome of one [`select`] call.
#[cfg(feature = "tls")]
enum SelectOutput<A> {
    /// A new raw connection was accepted from the listener.
    Incoming(A),
    /// A TLS handshake finished successfully.
    Io(ServerIo<A>),
    /// The listener (`incoming`) reported an error.
    TcpErr(crate::Error),
    /// A TLS handshake failed, or its task panicked or was cancelled.
    TlsErr(crate::Error),
    /// The listener stream ended.
    Done,
}
123 
124 /// Binds a socket address for a [Router](super::Router)
125 ///
126 /// An incoming stream, usable with [Router::serve_with_incoming](super::Router::serve_with_incoming),
127 /// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address.
128 #[derive(Debug)]
129 pub struct TcpIncoming {
130     inner: AddrIncoming,
131 }
132 
133 impl TcpIncoming {
134     /// Creates an instance by binding (opening) the specified socket address
135     /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
136     /// Returns a TcpIncoming if the socket address was successfully bound.
137     ///
138     /// # Examples
139     /// ```no_run
140     /// # use tower_service::Service;
141     /// # use http::{request::Request, response::Response};
142     /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}};
143     /// # use core::convert::Infallible;
144     /// # use std::error::Error;
145     /// # fn main() { }  // Cannot have type parameters, hence instead define:
146     /// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
147     /// # where
148     /// #   S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible> + NamedService + Clone + Send + 'static,
149     /// #   S::Future: Send + 'static,
150     /// # {
151     /// // Find a free port
152     /// let mut port = 1322;
153     /// let tinc = loop {
154     ///    let addr = format!("127.0.0.1:{}", port).parse().unwrap();
155     ///    match TcpIncoming::new(addr, true, None) {
156     ///       Ok(t) => break t,
157     ///       Err(_) => port += 1
158     ///    }
159     /// };
160     /// Server::builder()
161     ///    .add_service(some_service)
162     ///    .serve_with_incoming(tinc);
163     /// # Ok(())
164     /// # }
new( addr: SocketAddr, nodelay: bool, keepalive: Option<Duration>, ) -> Result<Self, crate::Error>165     pub fn new(
166         addr: SocketAddr,
167         nodelay: bool,
168         keepalive: Option<Duration>,
169     ) -> Result<Self, crate::Error> {
170         let mut inner = AddrIncoming::bind(&addr)?;
171         inner.set_nodelay(nodelay);
172         inner.set_keepalive(keepalive);
173         Ok(TcpIncoming { inner })
174     }
175 
176     /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
from_listener( listener: TcpListener, nodelay: bool, keepalive: Option<Duration>, ) -> Result<Self, crate::Error>177     pub fn from_listener(
178         listener: TcpListener,
179         nodelay: bool,
180         keepalive: Option<Duration>,
181     ) -> Result<Self, crate::Error> {
182         let mut inner = AddrIncoming::from_listener(listener)?;
183         inner.set_nodelay(nodelay);
184         inner.set_keepalive(keepalive);
185         Ok(TcpIncoming { inner })
186     }
187 }
188 
189 impl Stream for TcpIncoming {
190     type Item = Result<AddrStream, std::io::Error>;
191 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>192     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193         Pin::new(&mut self.inner).poll_accept(cx)
194     }
195 }
196 
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;

    /// Only one `TcpIncoming` can be bound to a given address at a time, and the address can be
    /// bound again once the first one is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Let the OS pick a free port rather than hard-coding one, so the test cannot collide
        // with another process or test already listening on a fixed port.
        let t1 = TcpIncoming::new("127.0.0.1:0".parse().unwrap(), true, None).unwrap();
        let addr = t1.inner.local_addr();

        let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();

        drop(t1);
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}
210