//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! Any stream implementing the standard `Read + Write` traits can be wrapped as
//! [`MaybeTlsStream::Plain`]. TLS support is optional and backend-specific: the
//! `native-tls` feature enables the `MaybeTlsStream::NativeTls` variant and the
//! `__rustls-tls` feature enables the `MaybeTlsStream::Rustls` variant. With neither
//! feature enabled, only plain streams are available.

#[cfg(feature = "__rustls-tls")]
use std::ops::Deref;
use std::{
    fmt::{self, Debug},
    io::{IoSlice, IoSliceMut, Read, Result as IoResult, Write},
    net::TcpStream,
};

#[cfg(feature = "native-tls")]
use native_tls_crate::TlsStream;
#[cfg(feature = "__rustls-tls")]
use rustls::StreamOwned;

/// Stream mode, either plain TCP or TLS.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Plain mode (`ws://` URL).
    Plain,
    /// TLS mode (`wss://` URL).
    Tls,
}

/// Trait to switch TCP_NODELAY.
pub trait NoDelay {
    /// Set the TCP_NODELAY option to the given value.
    ///
    /// # Errors
    ///
    /// The implementations in this module return whatever I/O error the
    /// underlying socket reports when the option is set.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>;
}

impl NoDelay for TcpStream {
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        // Path syntax selects the inherent `TcpStream::set_nodelay(&self, ..)`
        // rather than recursing into this trait method.
        TcpStream::set_nodelay(self, nodelay)
    }
}

#[cfg(feature = "native-tls")]
impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        // The option belongs to the transport, so configure the wrapped stream.
        self.get_mut().set_nodelay(nodelay)
    }
}

#[cfg(feature = "__rustls-tls")]
impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
    S: Deref<Target = rustls::ConnectionCommon<SD>>,
    SD: rustls::SideData,
    T: Read + Write + NoDelay,
{
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        // The option belongs to the transport, so configure the owned socket.
        self.sock.set_nodelay(nodelay)
    }
}

/// A stream that might be protected with TLS.
#[non_exhaustive]
pub enum MaybeTlsStream<S: Read + Write> {
    /// Unencrypted socket stream.
    Plain(S),
    #[cfg(feature = "native-tls")]
    /// Encrypted socket stream using `native-tls`.
    NativeTls(native_tls_crate::TlsStream<S>),
    #[cfg(feature = "__rustls-tls")]
    /// Encrypted socket stream using `rustls`.
    Rustls(rustls::StreamOwned<rustls::ClientConnection, S>),
}

impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => {
                // Local shim that formats the connection and socket fields directly;
                // presumably needed because `rustls::StreamOwned` itself does not
                // implement `Debug` (NOTE(review): confirm for the rustls version used).
                struct RustlsStreamDebug<'a, S: Read + Write>(
                    &'a rustls::StreamOwned<rustls::ClientConnection, S>,
                );

                impl<S: Read + Write + Debug> Debug for RustlsStreamDebug<'_, S> {
                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                        f.debug_struct("StreamOwned")
                            .field("conn", &self.0.conn)
                            .field("sock", &self.0.sock)
                            .finish()
                    }
                }

                f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
            }
        }
    }
}

impl<S: Read + Write> Read for MaybeTlsStream<S> {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        match self {
            Self::Plain(s) => s.read(buf),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.read(buf),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.read(buf),
        }
    }

    // Forwarded so an inner stream with a real scatter/gather implementation is
    // used; std's default only reads into the first non-empty buffer.
    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> IoResult<usize> {
        match self {
            Self::Plain(s) => s.read_vectored(bufs),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.read_vectored(bufs),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.read_vectored(bufs),
        }
    }
}

impl<S: Read + Write> Write for MaybeTlsStream<S> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        match self {
            Self::Plain(s) => s.write(buf),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.write(buf),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.write(buf),
        }
    }

    // Forwarded so an inner stream with a real scatter/gather implementation
    // (e.g. `TcpStream`) is used; std's default only writes the first
    // non-empty buffer.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> IoResult<usize> {
        match self {
            Self::Plain(s) => s.write_vectored(bufs),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.write_vectored(bufs),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.write_vectored(bufs),
        }
    }

    fn flush(&mut self) -> IoResult<()> {
        match self {
            Self::Plain(s) => s.flush(),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.flush(),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.flush(),
        }
    }
}

impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        match self {
            Self::Plain(s) => s.set_nodelay(nodelay),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(s) => s.set_nodelay(nodelay),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(s) => s.set_nodelay(nodelay),
        }
    }
}