1 use crate::transport::server::Connected; 2 use hyper::client::connect::{Connected as HyperConnected, Connection}; 3 use std::io; 4 use std::io::IoSlice; 5 use std::pin::Pin; 6 use std::task::{Context, Poll}; 7 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 #[cfg(feature = "tls")] 9 use tokio_rustls::server::TlsStream; 10 11 pub(in crate::transport) trait Io: 12 AsyncRead + AsyncWrite + Send + 'static 13 { 14 } 15 16 impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} 17 18 pub(crate) struct BoxedIo(Pin<Box<dyn Io>>); 19 20 impl BoxedIo { new<I: Io>(io: I) -> Self21 pub(in crate::transport) fn new<I: Io>(io: I) -> Self { 22 BoxedIo(Box::pin(io)) 23 } 24 } 25 26 impl Connection for BoxedIo { connected(&self) -> HyperConnected27 fn connected(&self) -> HyperConnected { 28 HyperConnected::new() 29 } 30 } 31 32 impl Connected for BoxedIo { 33 type ConnectInfo = NoneConnectInfo; 34 connect_info(&self) -> Self::ConnectInfo35 fn connect_info(&self) -> Self::ConnectInfo { 36 NoneConnectInfo 37 } 38 } 39 40 #[derive(Copy, Clone)] 41 pub(crate) struct NoneConnectInfo; 42 43 impl AsyncRead for BoxedIo { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>44 fn poll_read( 45 mut self: Pin<&mut Self>, 46 cx: &mut Context<'_>, 47 buf: &mut ReadBuf<'_>, 48 ) -> Poll<io::Result<()>> { 49 Pin::new(&mut self.0).poll_read(cx, buf) 50 } 51 } 52 53 impl AsyncWrite for BoxedIo { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>54 fn poll_write( 55 mut self: Pin<&mut Self>, 56 cx: &mut Context<'_>, 57 buf: &[u8], 58 ) -> Poll<io::Result<usize>> { 59 Pin::new(&mut self.0).poll_write(cx, buf) 60 } 61 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>62 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 63 Pin::new(&mut self.0).poll_flush(cx) 64 } 65 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>66 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 67 Pin::new(&mut self.0).poll_shutdown(cx) 68 } 69 poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>70 fn poll_write_vectored( 71 mut self: Pin<&mut Self>, 72 cx: &mut Context<'_>, 73 bufs: &[IoSlice<'_>], 74 ) -> Poll<Result<usize, io::Error>> { 75 Pin::new(&mut self.0).poll_write_vectored(cx, bufs) 76 } 77 is_write_vectored(&self) -> bool78 fn is_write_vectored(&self) -> bool { 79 self.0.is_write_vectored() 80 } 81 } 82 83 pub(crate) enum ServerIo<IO> { 84 Io(IO), 85 #[cfg(feature = "tls")] 86 TlsIo(Box<TlsStream<IO>>), 87 } 88 89 use tower::util::Either; 90 91 #[cfg(feature = "tls")] 92 type ServerIoConnectInfo<IO> = 93 Either<<IO as Connected>::ConnectInfo, <TlsStream<IO> as Connected>::ConnectInfo>; 94 95 #[cfg(not(feature = "tls"))] 96 type ServerIoConnectInfo<IO> = Either<<IO as Connected>::ConnectInfo, ()>; 97 98 impl<IO> ServerIo<IO> { new_io(io: IO) -> Self99 pub(in crate::transport) fn new_io(io: IO) -> Self { 100 Self::Io(io) 101 } 102 103 #[cfg(feature = "tls")] new_tls_io(io: TlsStream<IO>) -> Self104 pub(in crate::transport) fn new_tls_io(io: TlsStream<IO>) -> Self { 105 Self::TlsIo(Box::new(io)) 106 } 107 108 #[cfg(feature = "tls")] connect_info(&self) -> ServerIoConnectInfo<IO> where IO: Connected, TlsStream<IO>: Connected,109 pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO> 110 where 111 IO: Connected, 112 TlsStream<IO>: Connected, 113 { 114 match self { 115 Self::Io(io) => Either::A(io.connect_info()), 116 Self::TlsIo(io) => Either::B(io.connect_info()), 117 } 118 } 119 120 #[cfg(not(feature = "tls"))] connect_info(&self) -> ServerIoConnectInfo<IO> where IO: Connected,121 pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo<IO> 122 where 123 IO: Connected, 124 { 125 match self { 126 Self::Io(io) => Either::A(io.connect_info()), 127 } 128 } 129 } 130 131 impl<IO> AsyncRead for ServerIo<IO> 132 where 133 IO: AsyncWrite + AsyncRead + Unpin, 134 { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>135 fn poll_read( 136 mut self: Pin<&mut Self>, 137 cx: &mut Context<'_>, 138 buf: &mut ReadBuf<'_>, 139 ) -> Poll<io::Result<()>> { 140 match &mut *self { 141 Self::Io(io) => Pin::new(io).poll_read(cx, buf), 142 #[cfg(feature = "tls")] 143 Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), 144 } 145 } 146 } 147 148 impl<IO> AsyncWrite for ServerIo<IO> 149 where 150 IO: AsyncWrite + AsyncRead + Unpin, 151 { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>152 fn poll_write( 153 mut self: Pin<&mut Self>, 154 cx: &mut Context<'_>, 155 buf: &[u8], 156 ) -> Poll<io::Result<usize>> { 157 match &mut *self { 158 Self::Io(io) => Pin::new(io).poll_write(cx, buf), 159 #[cfg(feature = "tls")] 160 Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), 161 } 162 } 163 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>164 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 165 match &mut *self { 166 Self::Io(io) => Pin::new(io).poll_flush(cx), 167 #[cfg(feature = "tls")] 168 Self::TlsIo(io) => Pin::new(io).poll_flush(cx), 169 } 170 } 171 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>172 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 173 match &mut *self { 174 Self::Io(io) => Pin::new(io).poll_shutdown(cx), 175 #[cfg(feature = "tls")] 176 Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), 177 } 178 } 179 poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize, io::Error>>180 fn poll_write_vectored( 181 mut self: Pin<&mut Self>, 182 cx: &mut Context<'_>, 183 bufs: &[IoSlice<'_>], 184 ) -> Poll<Result<usize, io::Error>> { 185 match &mut *self { 186 Self::Io(io) => Pin::new(io).poll_write_vectored(cx, bufs), 187 #[cfg(feature = "tls")] 188 Self::TlsIo(io) => Pin::new(io).poll_write_vectored(cx, bufs), 189 } 190 } 191 is_write_vectored(&self) -> bool192 fn is_write_vectored(&self) -> bool { 193 match self { 194 Self::Io(io) => io.is_write_vectored(), 195 #[cfg(feature = "tls")] 196 Self::TlsIo(io) => io.is_write_vectored(), 197 } 198 } 199 } 200