1 use super::{Connected, Server};
2 use crate::transport::service::ServerIo;
3 use hyper::server::{
4 accept::Accept,
5 conn::{AddrIncoming, AddrStream},
6 };
7 use std::{
8 net::SocketAddr,
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12 };
13 use tokio::{
14 io::{AsyncRead, AsyncWrite},
15 net::TcpListener,
16 };
17 use tokio_stream::{Stream, StreamExt};
18
19 #[cfg(not(feature = "tls"))]
tcp_incoming<IO, IE, L>( incoming: impl Stream<Item = Result<IO, IE>>, _server: Server<L>, ) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into<crate::Error>,20 pub(crate) fn tcp_incoming<IO, IE, L>(
21 incoming: impl Stream<Item = Result<IO, IE>>,
22 _server: Server<L>,
23 ) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
24 where
25 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
26 IE: Into<crate::Error>,
27 {
28 async_stream::try_stream! {
29 tokio::pin!(incoming);
30
31 while let Some(item) = incoming.next().await {
32 yield item.map(ServerIo::new_io)?
33 }
34 }
35 }
36
/// Turns a stream of accepted connections into a stream of `ServerIo` values,
/// performing TLS handshakes when `server.tls` is configured.
///
/// Each handshake runs as a task in a `JoinSet`, so a slow handshake does not
/// stop new connections from being accepted. Without a TLS acceptor,
/// connections are yielded as plain IO right away. Errors from the listener or
/// from a handshake are logged at debug level and do not end the stream. The
/// loop stops when `incoming` is exhausted.
///
/// NOTE(review): handshakes still in flight when `incoming` ends are dropped
/// together with the `JoinSet` (tokio aborts its tasks on drop).
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE, L>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        tokio::pin!(incoming);

        // Pending TLS handshakes; each resolves to a ready-to-serve connection.
        let mut handshakes = tokio::task::JoinSet::new();

        loop {
            let event = select(&mut incoming, &mut handshakes).await;
            match event {
                SelectOutput::Done => {
                    break;
                }

                SelectOutput::Err(e) => {
                    tracing::debug!(message = "Accept loop error.", error = %e);
                }

                SelectOutput::Io(ready) => {
                    yield ready;
                }

                SelectOutput::Incoming(stream) => match &server.tls {
                    Some(acceptor) => {
                        let acceptor = acceptor.clone();
                        handshakes.spawn(async move {
                            let tls_io = acceptor.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(tls_io))
                        });
                    }
                    None => {
                        yield ServerIo::new_io(stream);
                    }
                },
            }
        }
    }
}
80
/// Waits for the next event of the TLS accept loop.
///
/// Resolves with whichever comes first: the next item from `incoming`, or the
/// completion of one of the handshakes pending in `tasks`.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::Error>,
{
    /// Maps the result of polling the listener stream onto a `SelectOutput`.
    fn classify<T, E: Into<crate::Error>>(next: Result<Option<T>, E>) -> SelectOutput<T> {
        match next {
            Ok(Some(stream)) => SelectOutput::Incoming(stream),
            Ok(None) => SelectOutput::Done,
            Err(e) => SelectOutput::Err(e.into()),
        }
    }

    // `join_next` on an empty set returns `None` immediately, which would trip
    // the `expect` below; with nothing in flight, only the listener is awaited.
    if tasks.is_empty() {
        let next = incoming.try_next().await;
        return classify(next);
    }

    tokio::select! {
        next = incoming.try_next() => {
            classify(next)
        }

        joined = tasks.join_next() => {
            let finished = joined.expect("JoinSet should never end");
            match finished {
                Ok(Ok(io)) => SelectOutput::Io(io),
                Ok(Err(e)) => SelectOutput::Err(e),
                Err(e) => SelectOutput::Err(e.into()),
            }
        }
    }
}
115
/// One event of the TLS accept loop, as produced by `select`.
#[cfg(feature = "tls")]
enum SelectOutput<IO> {
    /// A raw connection arrived from the listener and has not been wrapped yet.
    Incoming(IO),
    /// A spawned TLS handshake finished; the connection is ready to serve.
    Io(ServerIo<IO>),
    /// The listener, a handshake, or a handshake task failed.
    Err(crate::Error),
    /// The listener stream has ended.
    Done,
}
123
/// A stream of incoming TCP connections on a bound socket address, for use
/// with a [Router](super::Router).
///
/// Pass it to [Router::serve_with_incoming](super::Router::serve_with_incoming).
/// Each successfully accepted item is an `AsyncRead + AsyncWrite` connection to
/// a client that connected to the bound address.
#[derive(Debug)]
pub struct TcpIncoming {
    // Underlying hyper listener; the constructors set its nodelay/keepalive options.
    inner: AddrIncoming,
}
132
133 impl TcpIncoming {
134 /// Creates an instance by binding (opening) the specified socket address
135 /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
136 /// Returns a TcpIncoming if the socket address was successfully bound.
137 ///
138 /// # Examples
139 /// ```no_run
140 /// # use tower_service::Service;
141 /// # use http::{request::Request, response::Response};
142 /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}};
143 /// # use core::convert::Infallible;
144 /// # use std::error::Error;
145 /// # fn main() { } // Cannot have type parameters, hence instead define:
146 /// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
147 /// # where
148 /// # S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible> + NamedService + Clone + Send + 'static,
149 /// # S::Future: Send + 'static,
150 /// # {
151 /// // Find a free port
152 /// let mut port = 1322;
153 /// let tinc = loop {
154 /// let addr = format!("127.0.0.1:{}", port).parse().unwrap();
155 /// match TcpIncoming::new(addr, true, None) {
156 /// Ok(t) => break t,
157 /// Err(_) => port += 1
158 /// }
159 /// };
160 /// Server::builder()
161 /// .add_service(some_service)
162 /// .serve_with_incoming(tinc);
163 /// # Ok(())
164 /// # }
new( addr: SocketAddr, nodelay: bool, keepalive: Option<Duration>, ) -> Result<Self, crate::Error>165 pub fn new(
166 addr: SocketAddr,
167 nodelay: bool,
168 keepalive: Option<Duration>,
169 ) -> Result<Self, crate::Error> {
170 let mut inner = AddrIncoming::bind(&addr)?;
171 inner.set_nodelay(nodelay);
172 inner.set_keepalive(keepalive);
173 Ok(TcpIncoming { inner })
174 }
175
176 /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
from_listener( listener: TcpListener, nodelay: bool, keepalive: Option<Duration>, ) -> Result<Self, crate::Error>177 pub fn from_listener(
178 listener: TcpListener,
179 nodelay: bool,
180 keepalive: Option<Duration>,
181 ) -> Result<Self, crate::Error> {
182 let mut inner = AddrIncoming::from_listener(listener)?;
183 inner.set_nodelay(nodelay);
184 inner.set_keepalive(keepalive);
185 Ok(TcpIncoming { inner })
186 }
187 }
188
189 impl Stream for TcpIncoming {
190 type Item = Result<AddrStream, std::io::Error>;
191
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>192 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
193 Pin::new(&mut self.inner).poll_accept(cx)
194 }
195 }
196
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;

    /// Only one `TcpIncoming` may be bound to an address at a time, and the
    /// address becomes bindable again once the holder is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Let the OS pick a free port so the test cannot collide with another
        // process or test that happens to hold a fixed port.
        let ephemeral = "127.0.0.1:0".parse().unwrap();
        let addr = {
            let t1 = TcpIncoming::new(ephemeral, true, None).unwrap();
            let addr = t1.inner.local_addr();
            // While `t1` holds the port, a second bind to it must fail.
            let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
            addr
        };
        // `t1` has been dropped, so the same address can be bound again.
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}
210