1 use crate::transport::server::Connected;
2 use hyper::client::connect::{Connected as HyperConnected, Connection};
3 use std::io;
4 use std::io::IoSlice;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8 #[cfg(feature = "tls")]
9 use tokio_rustls::server::TlsStream;
10 
11 pub(in crate::transport) trait Io:
12     AsyncRead + AsyncWrite + Send + 'static
13 {
14 }
15 
16 impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}
17 
18 pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);
19 
20 impl BoxedIo {
new<I: Io>(io: I) -> Self21     pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
22         BoxedIo(Box::pin(io))
23     }
24 }
25 
26 impl Connection for BoxedIo {
connected(&self) -> HyperConnected27     fn connected(&self) -> HyperConnected {
28         HyperConnected::new()
29     }
30 }
31 
32 impl Connected for BoxedIo {
33     type ConnectInfo = NoneConnectInfo;
34 
connect_info(&self) -> Self::ConnectInfo35     fn connect_info(&self) -> Self::ConnectInfo {
36         NoneConnectInfo
37     }
38 }
39 
/// Zero-sized marker used as the connection info of a stream that exposes none.
///
/// `Debug` is derived (in addition to `Copy`/`Clone`) so the marker can appear in
/// diagnostic output and inside other `#[derive(Debug)]` types.
#[derive(Debug, Copy, Clone)]
pub(crate) struct NoneConnectInfo;
42 
43 impl AsyncRead for BoxedIo {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>44     fn poll_read(
45         mut self: Pin<&mut Self>,
46         cx: &mut Context<'_>,
47         buf: &mut ReadBuf<'_>,
48     ) -> Poll<io::Result<()>> {
49         Pin::new(&mut self.0).poll_read(cx, buf)
50     }
51 }
52 
53 impl AsyncWrite for BoxedIo {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>54     fn poll_write(
55         mut self: Pin<&mut Self>,
56         cx: &mut Context<'_>,
57         buf: &[u8],
58     ) -> Poll<io::Result<usize>> {
59         Pin::new(&mut self.0).poll_write(cx, buf)
60     }
61 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>62     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
63         Pin::new(&mut self.0).poll_flush(cx)
64     }
65 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>66     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67         Pin::new(&mut self.0).poll_shutdown(cx)
68     }
69 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>70     fn poll_write_vectored(
71         mut self: Pin<&mut Self>,
72         cx: &mut Context<'_>,
73         bufs: &[IoSlice<'_>],
74     ) -> Poll<Result<usize, io::Error>> {
75         Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
76     }
77 
is_write_vectored(&self) -> bool78     fn is_write_vectored(&self) -> bool {
79         self.0.is_write_vectored()
80     }
81 }
82 
/// A server-side stream that is either the raw transport or, when the `tls` feature is
/// enabled, that transport wrapped in a server-side TLS stream.
pub(crate) enum ServerIo<IO> {
    /// The underlying transport, used directly.
    Io(IO),
    /// The transport wrapped in a `TlsStream`; boxed, so this variant holds only a pointer.
    #[cfg(feature = "tls")]
    TlsIo(Box<TlsStream<IO>>),
}
88 
89 use tower::util::Either;
90 
91 #[cfg(feature = "tls")]
92 type ServerIoConnectInfo<IO> =
93     Either<<IO as Connected>::ConnectInfo, <TlsStream<IO> as Connected>::ConnectInfo>;
94 
95 #[cfg(not(feature = "tls"))]
96 type ServerIoConnectInfo<IO> = Either<<IO as Connected>::ConnectInfo, ()>;
97 
98 impl<IO> ServerIo<IO> {
new_io(io: IO) -> Self99     pub(in crate::transport) fn new_io(io: IO) -> Self {
100         Self::Io(io)
101     }
102 
103     #[cfg(feature = "tls")]
new_tls_io(io: TlsStream<IO>) -> Self104     pub(in crate::transport) fn new_tls_io(io: TlsStream<IO>) -> Self {
105         Self::TlsIo(Box::new(io))
106     }
107 
108     #[cfg(feature = "tls")]
connect_info(&self) -> ServerIoConnectInfo<IO> where IO: Connected, TlsStream<IO>: Connected,109     pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO>
110     where
111         IO: Connected,
112         TlsStream<IO>: Connected,
113     {
114         match self {
115             Self::Io(io) => Either::A(io.connect_info()),
116             Self::TlsIo(io) => Either::B(io.connect_info()),
117         }
118     }
119 
120     #[cfg(not(feature = "tls"))]
connect_info(&self) -> ServerIoConnectInfo<IO> where IO: Connected,121     pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO>
122     where
123         IO: Connected,
124     {
125         match self {
126             Self::Io(io) => Either::A(io.connect_info()),
127         }
128     }
129 }
130 
131 impl<IO> AsyncRead for ServerIo<IO>
132 where
133     IO: AsyncWrite + AsyncRead + Unpin,
134 {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>135     fn poll_read(
136         mut self: Pin<&mut Self>,
137         cx: &mut Context<'_>,
138         buf: &mut ReadBuf<'_>,
139     ) -> Poll<io::Result<()>> {
140         match &mut *self {
141             Self::Io(io) => Pin::new(io).poll_read(cx, buf),
142             #[cfg(feature = "tls")]
143             Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf),
144         }
145     }
146 }
147 
148 impl<IO> AsyncWrite for ServerIo<IO>
149 where
150     IO: AsyncWrite + AsyncRead + Unpin,
151 {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>152     fn poll_write(
153         mut self: Pin<&mut Self>,
154         cx: &mut Context<'_>,
155         buf: &[u8],
156     ) -> Poll<io::Result<usize>> {
157         match &mut *self {
158             Self::Io(io) => Pin::new(io).poll_write(cx, buf),
159             #[cfg(feature = "tls")]
160             Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf),
161         }
162     }
163 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>164     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
165         match &mut *self {
166             Self::Io(io) => Pin::new(io).poll_flush(cx),
167             #[cfg(feature = "tls")]
168             Self::TlsIo(io) => Pin::new(io).poll_flush(cx),
169         }
170     }
171 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>172     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173         match &mut *self {
174             Self::Io(io) => Pin::new(io).poll_shutdown(cx),
175             #[cfg(feature = "tls")]
176             Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx),
177         }
178     }
179 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>180     fn poll_write_vectored(
181         mut self: Pin<&mut Self>,
182         cx: &mut Context<'_>,
183         bufs: &[IoSlice<'_>],
184     ) -> Poll<Result<usize, io::Error>> {
185         match &mut *self {
186             Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs),
187             #[cfg(feature = "tls")]
188             Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs),
189         }
190     }
191 
is_write_vectored(&self) -> bool192     fn is_write_vectored(&self) -> bool {
193         match self {
194             Self::Io(io) => io.is_write_vectored(),
195             #[cfg(feature = "tls")]
196             Self::TlsIo(io) => io.is_write_vectored(),
197         }
198     }
199 }
200