//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! There is no hard dependency on a particular TLS implementation: the `native-tls` and
//! `rustls` backends are only compiled in when their respective Cargo features are enabled.
//! Libraries such as `native_tls` or `openssl` can be used as long as they provide a TLS
//! stream implementing the standard `Read + Write` traits.
6 
7 #[cfg(feature = "__rustls-tls")]
8 use std::ops::Deref;
9 use std::{
10     fmt::{self, Debug},
11     io::{Read, Result as IoResult, Write},
12 };
13 
14 use std::net::TcpStream;
15 
16 #[cfg(feature = "native-tls")]
17 use native_tls_crate::TlsStream;
18 #[cfg(feature = "__rustls-tls")]
19 use rustls::StreamOwned;
20 
/// Stream mode, either plain TCP or TLS.
///
/// The two variants correspond to the `ws://` and `wss://` URL schemes respectively.
/// The type is `Copy` and supports equality and hashing, so it can be compared directly
/// (`mode == Mode::Tls`) or used as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Plain mode (`ws://` URL).
    Plain,
    /// TLS mode (`wss://` URL).
    Tls,
}
29 
/// Abstraction over streams whose `TCP_NODELAY` socket option can be toggled.
///
/// Wrapping streams (for example TLS streams) are expected to forward the call to the
/// transport they wrap.
pub trait NoDelay {
    /// Enables (`true`) or disables (`false`) `TCP_NODELAY` on the underlying socket.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementation reports while applying the option.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()>;
}
35 
36 impl NoDelay for TcpStream {
set_nodelay(&mut self, nodelay: bool) -> IoResult<()>37     fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
38         TcpStream::set_nodelay(self, nodelay)
39     }
40 }
41 
#[cfg(feature = "native-tls")]
impl<S> NoDelay for TlsStream<S>
where
    S: Read + Write + NoDelay,
{
    /// Applies the option to the transport stream wrapped by the `native-tls` session.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        let transport: &mut S = self.get_mut();
        NoDelay::set_nodelay(transport, nodelay)
    }
}
48 
#[cfg(feature = "__rustls-tls")]
impl<C, D, T> NoDelay for StreamOwned<C, T>
where
    C: Deref<Target = rustls::ConnectionCommon<D>>,
    D: rustls::SideData,
    T: Read + Write + NoDelay,
{
    /// Applies the option to the `sock` field owned by the rustls stream.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        NoDelay::set_nodelay(&mut self.sock, nodelay)
    }
}
60 
/// A byte stream that is either used as-is or wrapped in a TLS session.
///
/// The TLS variants only exist when the corresponding Cargo feature (`native-tls`, or the
/// rustls feature) is enabled. The enum is `#[non_exhaustive]` so more backends can be
/// added later.
#[non_exhaustive]
pub enum MaybeTlsStream<S: Read + Write> {
    /// Unencrypted socket stream.
    Plain(S),
    /// Encrypted socket stream using `native-tls`.
    #[cfg(feature = "native-tls")]
    NativeTls(TlsStream<S>),
    /// Encrypted socket stream using `rustls`.
    #[cfg(feature = "__rustls-tls")]
    Rustls(StreamOwned<rustls::ClientConnection, S>),
}
73 
74 impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result75     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76         match self {
77             Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
78             #[cfg(feature = "native-tls")]
79             Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
80             #[cfg(feature = "__rustls-tls")]
81             Self::Rustls(s) => {
82                 struct RustlsStreamDebug<'a, S: Read + Write>(
83                     &'a rustls::StreamOwned<rustls::ClientConnection, S>,
84                 );
85 
86                 impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
87                     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88                         f.debug_struct("StreamOwned")
89                             .field("conn", &self.0.conn)
90                             .field("sock", &self.0.sock)
91                             .finish()
92                     }
93                 }
94 
95                 f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
96             }
97         }
98     }
99 }
100 
101 impl<S: Read + Write> Read for MaybeTlsStream<S> {
read(&mut self, buf: &mut [u8]) -> IoResult<usize>102     fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
103         match *self {
104             MaybeTlsStream::Plain(ref mut s) => s.read(buf),
105             #[cfg(feature = "native-tls")]
106             MaybeTlsStream::NativeTls(ref mut s) => s.read(buf),
107             #[cfg(feature = "__rustls-tls")]
108             MaybeTlsStream::Rustls(ref mut s) => s.read(buf),
109         }
110     }
111 }
112 
113 impl<S: Read + Write> Write for MaybeTlsStream<S> {
write(&mut self, buf: &[u8]) -> IoResult<usize>114     fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
115         match *self {
116             MaybeTlsStream::Plain(ref mut s) => s.write(buf),
117             #[cfg(feature = "native-tls")]
118             MaybeTlsStream::NativeTls(ref mut s) => s.write(buf),
119             #[cfg(feature = "__rustls-tls")]
120             MaybeTlsStream::Rustls(ref mut s) => s.write(buf),
121         }
122     }
123 
flush(&mut self) -> IoResult<()>124     fn flush(&mut self) -> IoResult<()> {
125         match *self {
126             MaybeTlsStream::Plain(ref mut s) => s.flush(),
127             #[cfg(feature = "native-tls")]
128             MaybeTlsStream::NativeTls(ref mut s) => s.flush(),
129             #[cfg(feature = "__rustls-tls")]
130             MaybeTlsStream::Rustls(ref mut s) => s.flush(),
131         }
132     }
133 }
134 
135 impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
set_nodelay(&mut self, nodelay: bool) -> IoResult<()>136     fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
137         match *self {
138             MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay),
139             #[cfg(feature = "native-tls")]
140             MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay),
141             #[cfg(feature = "__rustls-tls")]
142             MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay),
143         }
144     }
145 }
146