1 //! WebSocket handshake control.
2
3 pub mod client;
4 pub mod headers;
5 pub mod machine;
6 pub mod server;
7
8 use std::{
9 error::Error as ErrorTrait,
10 fmt,
11 io::{Read, Write},
12 };
13
14 use sha1::{Digest, Sha1};
15
16 use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
17 use crate::error::Error;
18
/// A WebSocket handshake.
///
/// Bundles the role-specific logic with the handshake machine so that an unfinished
/// handshake can be carried around as a value and driven further with
/// [`MidHandshake::handshake`].
#[derive(Debug)]
pub struct MidHandshake<Role: HandshakeRole> {
    // Role-specific logic; `handshake` calls `stage_finished` on it whenever a stage completes.
    role: Role,
    // Machine that `handshake` advances one round at a time; exposed via `get_ref` / `get_mut`.
    machine: HandshakeMachine<Role::InternalStream>,
}
25
26 impl<Role: HandshakeRole> MidHandshake<Role> {
27 /// Allow access to machine
get_ref(&self) -> &HandshakeMachine<Role::InternalStream>28 pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
29 &self.machine
30 }
31
32 /// Allow mutable access to machine
get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream>33 pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
34 &mut self.machine
35 }
36
37 /// Restarts the handshake process.
handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>>38 pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
39 let mut mach = self.machine;
40 loop {
41 mach = match mach.single_round()? {
42 RoundResult::WouldBlock(m) => {
43 return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
44 }
45 RoundResult::Incomplete(m) => m,
46 RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
47 ProcessingResult::Continue(m) => m,
48 ProcessingResult::Done(result) => return Ok(result),
49 },
50 }
51 }
52 }
53 }
54
/// The error half of a handshake result.
///
/// Returned by [`MidHandshake::handshake`] when the handshake did not run to completion.
pub enum HandshakeError<Role: HandshakeRole> {
    /// Handshake was interrupted (would block).
    ///
    /// Carries the in-progress [`MidHandshake`] so that [`MidHandshake::handshake`] can be
    /// called on it again later.
    Interrupted(MidHandshake<Role>),
    /// Handshake failed.
    Failure(Error),
}
62
63 impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 match *self {
66 HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
67 HandshakeError::Failure(ref e) => write!(f, "HandshakeError::Failure({:?})", e),
68 }
69 }
70 }
71
72 impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result73 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74 match *self {
75 HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
76 HandshakeError::Failure(ref e) => write!(f, "{}", e),
77 }
78 }
79 }
80
// Empty impl: every `Error` trait method keeps its default, so `source()` is `None`; the
// failure text is reported through the `Display` and `Debug` impls instead.
impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {}
82
83 impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
from(err: Error) -> Self84 fn from(err: Error) -> Self {
85 HandshakeError::Failure(err)
86 }
87 }
88
/// Handshake role.
///
/// Supplies the role-specific types and the callback that [`MidHandshake::handshake`]
/// invokes each time the machine reports a finished stage.
pub trait HandshakeRole {
    // Payload type of the `StageResult` handed to `stage_finished`; must implement `TryParse`.
    #[doc(hidden)]
    type IncomingData: TryParse;
    // Stream type the handshake machine is parameterized over; must be readable and writable.
    #[doc(hidden)]
    type InternalStream: Read + Write;
    // Value that `MidHandshake::handshake` returns on success.
    #[doc(hidden)]
    type FinalResult;
    // Called by `MidHandshake::handshake` with the result of a finished stage. Returns either
    // the machine to continue with or the final result; an `Err` aborts the handshake.
    #[doc(hidden)]
    fn stage_finished(
        &mut self,
        finish: StageResult<Self::IncomingData, Self::InternalStream>,
    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
103
/// Stage processing result.
///
/// Returned by [`HandshakeRole::stage_finished`] to tell [`MidHandshake::handshake`] what to
/// do next.
#[doc(hidden)]
#[derive(Debug)]
pub enum ProcessingResult<Stream, FinalResult> {
    /// Keep going: the handshake loop continues with this machine.
    Continue(HandshakeMachine<Stream>),
    /// The handshake is complete; this value is returned to the caller.
    Done(FinalResult),
}
111
112 /// Derive the `Sec-WebSocket-Accept` response header from a `Sec-WebSocket-Key` request header.
113 ///
114 /// This function can be used to perform a handshake before passing a raw TCP stream to
115 /// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
derive_accept_key(request_key: &[u8]) -> String116 pub fn derive_accept_key(request_key: &[u8]) -> String {
117 // ... field is constructed by concatenating /key/ ...
118 // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
119 const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
120 let mut sha1 = Sha1::default();
121 sha1.update(request_key);
122 sha1.update(WS_GUID);
123 data_encoding::BASE64.encode(&sha1.finalize())
124 }
125
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// The sample key/accept pair given as an example in RFC 6455.
    #[test]
    fn key_conversion() {
        let request_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let accept_key = derive_accept_key(request_key);
        assert_eq!(accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }
}
136