1 //! WebSocket handshake machine. 2 3 use bytes::Buf; 4 use log::*; 5 use std::io::{Cursor, Read, Write}; 6 7 use crate::{ 8 error::{Error, ProtocolError, Result}, 9 util::NonBlockingResult, 10 ReadBuffer, 11 }; 12 13 /// A generic handshake state machine. 14 #[derive(Debug)] 15 pub struct HandshakeMachine<Stream> { 16 stream: Stream, 17 state: HandshakeState, 18 } 19 20 impl<Stream> HandshakeMachine<Stream> { 21 /// Start reading data from the peer. start_read(stream: Stream) -> Self22 pub fn start_read(stream: Stream) -> Self { 23 Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) } 24 } 25 /// Start writing data to the peer. start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self26 pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self { 27 HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) } 28 } 29 /// Returns a shared reference to the inner stream. get_ref(&self) -> &Stream30 pub fn get_ref(&self) -> &Stream { 31 &self.stream 32 } 33 /// Returns a mutable reference to the inner stream. get_mut(&mut self) -> &mut Stream34 pub fn get_mut(&mut self) -> &mut Stream { 35 &mut self.stream 36 } 37 } 38 39 impl<Stream: Read + Write> HandshakeMachine<Stream> { 40 /// Perform a single handshake round. single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>>41 pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> { 42 trace!("Doing handshake round."); 43 match self.state { 44 HandshakeState::Reading(mut buf, mut attack_check) => { 45 let read = buf.read_from(&mut self.stream).no_block()?; 46 match read { 47 Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)), 48 Some(count) => { 49 attack_check.check_incoming_packet_size(count)?; 50 // TODO: this is slow for big headers with too many small packets. 51 // The parser has to be reworked in order to work on streams instead 52 // of buffers. 53 Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? { 54 buf.advance(size); 55 RoundResult::StageFinished(StageResult::DoneReading { 56 result: obj, 57 stream: self.stream, 58 tail: buf.into_vec(), 59 }) 60 } else { 61 RoundResult::Incomplete(HandshakeMachine { 62 state: HandshakeState::Reading(buf, attack_check), 63 ..self 64 }) 65 }) 66 } 67 None => Ok(RoundResult::WouldBlock(HandshakeMachine { 68 state: HandshakeState::Reading(buf, attack_check), 69 ..self 70 })), 71 } 72 } 73 HandshakeState::Writing(mut buf) => { 74 assert!(buf.has_remaining()); 75 if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? { 76 assert!(size > 0); 77 buf.advance(size); 78 Ok(if buf.has_remaining() { 79 RoundResult::Incomplete(HandshakeMachine { 80 state: HandshakeState::Writing(buf), 81 ..self 82 }) 83 } else { 84 RoundResult::StageFinished(StageResult::DoneWriting(self.stream)) 85 }) 86 } else { 87 Ok(RoundResult::WouldBlock(HandshakeMachine { 88 state: HandshakeState::Writing(buf), 89 ..self 90 })) 91 } 92 } 93 } 94 } 95 } 96 97 /// The result of the round. 98 #[derive(Debug)] 99 pub enum RoundResult<Obj, Stream> { 100 /// Round not done, I/O would block. 101 WouldBlock(HandshakeMachine<Stream>), 102 /// Round done, state unchanged. 103 Incomplete(HandshakeMachine<Stream>), 104 /// Stage complete. 105 StageFinished(StageResult<Obj, Stream>), 106 } 107 108 /// The result of the stage. 109 #[derive(Debug)] 110 pub enum StageResult<Obj, Stream> { 111 /// Reading round finished. 112 #[allow(missing_docs)] 113 DoneReading { result: Obj, stream: Stream, tail: Vec<u8> }, 114 /// Writing round finished. 115 DoneWriting(Stream), 116 } 117 118 /// The parseable object. 119 pub trait TryParse: Sized { 120 /// Return Ok(None) if incomplete, Err on syntax error. try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>121 fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>; 122 } 123 124 /// The handshake state. 125 #[derive(Debug)] 126 enum HandshakeState { 127 /// Reading data from the peer. 128 Reading(ReadBuffer, AttackCheck), 129 /// Sending data to the peer. 130 Writing(Cursor<Vec<u8>>), 131 } 132 133 /// Attack mitigation. Contains counters needed to prevent DoS attacks 134 /// and reject valid but useless headers. 135 #[derive(Debug)] 136 pub(crate) struct AttackCheck { 137 /// Number of HTTP header successful reads (TCP packets). 138 number_of_packets: usize, 139 /// Total number of bytes in HTTP header. 140 number_of_bytes: usize, 141 } 142 143 impl AttackCheck { 144 /// Initialize attack checking for incoming buffer. new() -> Self145 fn new() -> Self { 146 Self { number_of_packets: 0, number_of_bytes: 0 } 147 } 148 149 /// Check the size of an incoming packet. To be called immediately after `read()` 150 /// passing its returned bytes count as `size`. check_incoming_packet_size(&mut self, size: usize) -> Result<()>151 fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> { 152 self.number_of_packets += 1; 153 self.number_of_bytes += size; 154 155 // TODO: these values are hardcoded. Instead of making them configurable, 156 // rework the way HTTP header is parsed to remove this check at all. 157 const MAX_BYTES: usize = 65536; 158 const MAX_PACKETS: usize = 512; 159 const MIN_PACKET_SIZE: usize = 128; 160 const MIN_PACKET_CHECK_THRESHOLD: usize = 64; 161 162 if self.number_of_bytes > MAX_BYTES { 163 return Err(Error::AttackAttempt); 164 } 165 166 if self.number_of_packets > MAX_PACKETS { 167 return Err(Error::AttackAttempt); 168 } 169 170 if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD 171 && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes 172 { 173 return Err(Error::AttackAttempt); 174 } 175 176 Ok(()) 177 } 178 } 179