1 //! WebSocket handshake control.
2 
3 pub mod client;
4 pub mod headers;
5 pub mod machine;
6 pub mod server;
7 
8 use std::{
9     error::Error as ErrorTrait,
10     fmt,
11     io::{Read, Write},
12 };
13 
14 use sha1::{Digest, Sha1};
15 
16 use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
17 use crate::error::Error;
18 
/// A WebSocket handshake that is still in progress.
///
/// Pairs the role-specific stage handling (`role`) with the state machine that
/// performs reads and writes on the stream (`machine`). A value of this type is
/// handed back inside `HandshakeError::Interrupted` when a round would block, so
/// the caller can resume later via `handshake()`.
#[derive(Debug)]
pub struct MidHandshake<Role: HandshakeRole> {
    // Decides what to do each time the machine finishes a stage.
    role: Role,
    // Read/write state machine over the role's internal stream type.
    machine: HandshakeMachine<Role::InternalStream>,
}
25 
26 impl<Role: HandshakeRole> MidHandshake<Role> {
27     /// Allow access to machine
get_ref(&self) -> &HandshakeMachine<Role::InternalStream>28     pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
29         &self.machine
30     }
31 
32     /// Allow mutable access to machine
get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream>33     pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
34         &mut self.machine
35     }
36 
37     /// Restarts the handshake process.
handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>>38     pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
39         let mut mach = self.machine;
40         loop {
41             mach = match mach.single_round()? {
42                 RoundResult::WouldBlock(m) => {
43                     return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
44                 }
45                 RoundResult::Incomplete(m) => m,
46                 RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
47                     ProcessingResult::Continue(m) => m,
48                     ProcessingResult::Done(result) => return Ok(result),
49                 },
50             }
51         }
52     }
53 }
54 
/// Error (or pause) produced while driving a handshake with `MidHandshake::handshake`.
///
/// `Interrupted` is not necessarily fatal: it carries the paused handshake so
/// the caller can resume it once the stream is ready again.
pub enum HandshakeError<Role: HandshakeRole> {
    /// Handshake was interrupted because a round would block; call
    /// `handshake()` on the contained `MidHandshake` to resume.
    Interrupted(MidHandshake<Role>),
    /// Handshake failed with the contained error.
    Failure(Error),
}
62 
63 impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result64     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65         match *self {
66             HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
67             HandshakeError::Failure(ref e) => write!(f, "HandshakeError::Failure({:?})", e),
68         }
69     }
70 }
71 
72 impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result73     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74         match *self {
75             HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
76             HandshakeError::Failure(ref e) => write!(f, "{}", e),
77         }
78     }
79 }
80 
81 impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {}
82 
83 impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
from(err: Error) -> Self84     fn from(err: Error) -> Self {
85         HandshakeError::Failure(err)
86     }
87 }
88 
/// Handshake role.
///
/// Supplies the side-specific logic for a handshake. `MidHandshake::handshake`
/// runs the shared `HandshakeMachine` and passes every finished stage to
/// [`HandshakeRole::stage_finished`], which decides whether to continue or finish.
pub trait HandshakeRole {
    /// Data parsed (via `TryParse`) when a read stage finishes.
    /// Presumably the peer's HTTP request/response — confirm in the role impls.
    #[doc(hidden)]
    type IncomingData: TryParse;
    /// Stream type the handshake machine reads from and writes to.
    #[doc(hidden)]
    type InternalStream: Read + Write;
    /// Value returned by `MidHandshake::handshake` once the handshake completes.
    #[doc(hidden)]
    type FinalResult;
    /// Handle a finished stage of the machine.
    ///
    /// Returns `ProcessingResult::Continue` with the machine to keep running, or
    /// `ProcessingResult::Done` with the final value. An `Err` aborts the
    /// handshake; the caller surfaces it as `HandshakeError::Failure` via `?`.
    #[doc(hidden)]
    fn stage_finished(
        &mut self,
        finish: StageResult<Self::IncomingData, Self::InternalStream>,
    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
}
103 
/// Stage processing result.
///
/// Returned by `HandshakeRole::stage_finished`. It tells the handshake loop
/// whether to keep driving a machine or to stop with a final value.
#[doc(hidden)]
#[derive(Debug)]
pub enum ProcessingResult<Stream, FinalResult> {
    /// Keep going: the loop continues with this machine.
    Continue(HandshakeMachine<Stream>),
    /// Finished: the loop returns this value as `Ok(..)` to the caller.
    Done(FinalResult),
}
111 
112 /// Derive the `Sec-WebSocket-Accept` response header from a `Sec-WebSocket-Key` request header.
113 ///
114 /// This function can be used to perform a handshake before passing a raw TCP stream to
115 /// [`WebSocket::from_raw_socket`][crate::protocol::WebSocket::from_raw_socket].
derive_accept_key(request_key: &[u8]) -> String116 pub fn derive_accept_key(request_key: &[u8]) -> String {
117     // ... field is constructed by concatenating /key/ ...
118     // ... with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
119     const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
120     let mut sha1 = Sha1::default();
121     sha1.update(request_key);
122     sha1.update(WS_GUID);
123     data_encoding::BASE64.encode(&sha1.finalize())
124 }
125 
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// Checks the worked key/accept example given in RFC 6455.
    #[test]
    fn key_conversion() {
        let client_key: &[u8] = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(client_key), expected_accept);
    }
}
136