1 use crate::frame::{self, Frame, Kind, Reason};
2 use crate::frame::{
3     DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
4 };
5 use crate::proto::Error;
6 
7 use crate::hpack;
8 
9 use futures_core::Stream;
10 
11 use bytes::BytesMut;
12 
13 use std::io;
14 
15 use std::pin::Pin;
16 use std::task::{Context, Poll};
17 use tokio::io::AsyncRead;
18 use tokio_util::codec::FramedRead as InnerFramedRead;
19 use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20 
/// Default SETTINGS_MAX_HEADER_LIST_SIZE used until the peer's settings say
/// otherwise: a 16 MiB "sane default" taken from golang http2.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23 
/// Reads HTTP/2 frames from a length-delimited byte stream.
///
/// Each chunk yielded by `inner` is handed to `decode_frame`. Header blocks are
/// HPACK-decoded, and HEADERS/PUSH_PROMISE frames that are split across
/// CONTINUATION frames are reassembled before being yielded.
#[derive(Debug)]
pub struct FramedRead<T> {
    // Underlying length-delimited reader; each chunk it yields is decoded as
    // one frame by `decode_frame`.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    // hpack decoder state
    hpack: hpack::Decoder,

    // Limit passed to `load_hpack` when decoding header blocks. It is also used
    // to bound buffered bytes for over-size header blocks.
    max_header_list_size: usize,

    // Maximum number of CONTINUATION frames without END_HEADERS accepted for a
    // single header block (CONTINUATION flood protection). Recomputed from
    // `max_header_list_size` and the max frame size whenever either changes.
    max_continuation_frames: usize,

    // Header block still waiting for CONTINUATION frames, if any.
    partial: Option<Partial>,
}
37 
/// Partially loaded headers frame
///
/// Holds a HEADERS or PUSH_PROMISE frame whose header block was not terminated
/// by END_HEADERS. It stays here until the remaining CONTINUATION frames
/// arrive.
#[derive(Debug)]
struct Partial {
    /// The HEADERS / PUSH_PROMISE frame that started the header block.
    /// `load_hpack` is called on it again as each CONTINUATION frame arrives.
    frame: Continuable,

    /// Partial header payload: header-block bytes not yet consumed by the
    /// HPACK decoder (left over when decoding stopped with `NeedMore`).
    buf: BytesMut,

    /// Number of CONTINUATION frames without END_HEADERS received so far for
    /// this block; compared against `max_continuation_frames`.
    continuation_frames_count: usize,
}
49 
/// A frame whose header block may be continued by CONTINUATION frames.
#[derive(Debug)]
enum Continuable {
    Headers(frame::Headers),
    PushPromise(frame::PushPromise),
}
55 
56 impl<T> FramedRead<T> {
new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T>57     pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
58         let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE;
59         let max_continuation_frames =
60             calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length());
61         FramedRead {
62             inner,
63             hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
64             max_header_list_size,
65             max_continuation_frames,
66             partial: None,
67         }
68     }
69 
get_ref(&self) -> &T70     pub fn get_ref(&self) -> &T {
71         self.inner.get_ref()
72     }
73 
get_mut(&mut self) -> &mut T74     pub fn get_mut(&mut self) -> &mut T {
75         self.inner.get_mut()
76     }
77 
78     /// Returns the current max frame size setting
79     #[inline]
max_frame_size(&self) -> usize80     pub fn max_frame_size(&self) -> usize {
81         self.inner.decoder().max_frame_length()
82     }
83 
84     /// Updates the max frame size setting.
85     ///
86     /// Must be within 16,384 and 16,777,215.
87     #[inline]
set_max_frame_size(&mut self, val: usize)88     pub fn set_max_frame_size(&mut self, val: usize) {
89         assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
90         self.inner.decoder_mut().set_max_frame_length(val);
91         // Update max CONTINUATION frames too, since its based on this
92         self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val);
93     }
94 
95     /// Update the max header list size setting.
96     #[inline]
set_max_header_list_size(&mut self, val: usize)97     pub fn set_max_header_list_size(&mut self, val: usize) {
98         self.max_header_list_size = val;
99         // Update max CONTINUATION frames too, since its based on this
100         self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size());
101     }
102 
103     /// Update the header table size setting.
104     #[inline]
set_header_table_size(&mut self, val: usize)105     pub fn set_header_table_size(&mut self, val: usize) {
106         self.hpack.queue_size_update(val);
107     }
108 }
109 
/// Computes how many CONTINUATION frames a single header block may use before
/// the peer is treated as flooding the connection.
///
/// The base value is the number of frames needed to carry `header_max` bytes
/// when each frame holds at most `frame_max` bytes, rounded *up*: a trailing
/// remainder still needs its own frame. Roughly 25% padding is added on top
/// for imperfectly packed frames. The result is never less than 5 and
/// saturates at `usize::MAX`.
///
/// A `frame_max` of 0 is treated as 1 so this can never divide by zero.
fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize {
    let frame_max = frame_max.max(1);
    // At least this many frames needed to use max header list size.
    // Ceiling division written as quotient + (remainder != 0), so it cannot
    // overflow the way `header_max + frame_max - 1` could.
    let whole_frames = header_max / frame_max;
    let trailing_frame = usize::from(header_max % frame_max != 0);
    let min_frames_for_list = (whole_frames + trailing_frame).max(1);
    // Some padding for imperfectly packed frames
    // 25% without floats
    let padding = min_frames_for_list >> 2;
    min_frames_for_list.saturating_add(padding).max(5)
}
118 
/// Decodes a frame.
///
/// This method is intentionally de-generified and outlined because it is very large.
///
/// `bytes` presumably holds exactly one frame as produced by the
/// length-delimited codec, starting with its `frame::HEADER_LEN`-byte frame
/// header. Every branch below parses `head` from it, then skips or slices past
/// that header before loading the payload.
///
/// HEADERS and PUSH_PROMISE frames whose header block does not end in this
/// frame are stored in `partial_inout`. Subsequent CONTINUATION frames are
/// appended to that entry until one carries END_HEADERS. While a block is
/// pending, receiving any frame other than CONTINUATION is a connection error.
///
/// # Returns
///
/// * `Ok(Some(frame))` - a complete frame was decoded.
/// * `Ok(None)` - the bytes were consumed without producing a frame. Either a
///   header block is still waiting for CONTINUATION frames, or the frame kind
///   is unknown and was ignored.
///
/// # Errors
///
/// * A stream reset with `PROTOCOL_ERROR` (`Error::library_reset`) for invalid
///   dependency IDs and malformed header blocks.
/// * A connection GOAWAY (`Error::library_go_away`) for other protocol
///   violations.
/// * A GOAWAY with `ENHANCE_YOUR_CALM` when more than
///   `max_continuation_frames` CONTINUATION frames without END_HEADERS arrive
///   for one header block.
fn decode_frame(
    hpack: &mut hpack::Decoder,
    max_header_list_size: usize,
    max_continuation_frames: usize,
    partial_inout: &mut Option<Partial>,
    mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
    let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
    let _e = span.enter();

    tracing::trace!("decoding frame from {}B", bytes.len());

    // Parse the head
    let head = frame::Head::parse(&bytes);

    // A header block in progress must be followed only by its CONTINUATION
    // frames; anything else is a connection-level protocol error.
    if partial_inout.is_some() && head.kind() != Kind::Continuation {
        proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
        return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
    }

    let kind = head.kind();

    tracing::trace!(frame.kind = ?kind);

    // Shared loading logic for HEADERS and PUSH_PROMISE. It evaluates to the
    // finished frame when END_HEADERS is set. Otherwise it stores the frame in
    // `partial_inout` and returns `Ok(None)` from `decode_frame`.
    macro_rules! header_block {
        ($frame:ident, $head:ident, $bytes:ident) => ({
            // Drop the frame header
            // TODO: Change to drain: carllerche/bytes#130
            let _ = $bytes.split_to(frame::HEADER_LEN);

            // Parse the header frame w/o parsing the payload
            let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
                Ok(res) => res,
                Err(frame::Error::InvalidDependencyId) => {
                    proto_err!(stream: "invalid HEADERS dependency ID");
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed to load frame; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            let is_end_headers = frame.is_end_headers();

            // Load the HPACK encoded headers
            match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
                Ok(_) => {},
                // Running out of input is expected when the block continues in
                // CONTINUATION frames; the leftover bytes stay in `payload`.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
                Err(frame::Error::MalformedMessage) => {
                    let id = $head.stream_id();
                    proto_err!(stream: "malformed header block; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                frame.into()
            } else {
                tracing::trace!("loaded partial header block");
                // Defer returning the frame
                *partial_inout = Some(Partial {
                    frame: Continuable::$frame(frame),
                    buf: payload,
                    continuation_frames_count: 0,
                });

                return Ok(None);
            }
        });
    }

    let frame = match kind {
        Kind::Settings => {
            let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Ping => {
            let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load PING frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::WindowUpdate => {
            let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Data => {
            // Split off the header so the payload can be frozen into `Bytes`
            // without copying.
            let _ = bytes.split_to(frame::HEADER_LEN);
            let res = frame::Data::load(head, bytes.freeze());

            // TODO: Should this always be connection level? Probably not...
            res.map_err(|e| {
                proto_err!(conn: "failed to load DATA frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Headers => header_block!(Headers, head, bytes),
        Kind::Reset => {
            let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load RESET frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::GoAway => {
            let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::PushPromise => header_block!(PushPromise, head, bytes),
        Kind::Priority => {
            if head.stream_id() == 0 {
                // Invalid stream identifier
                proto_err!(conn: "invalid stream ID 0");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
                Ok(frame) => frame.into(),
                Err(frame::Error::InvalidDependencyId) => {
                    // A stream cannot depend on itself. An endpoint MUST
                    // treat this as a stream error (Section 5.4.2) of type
                    // `PROTOCOL_ERROR`.
                    let id = head.stream_id();
                    proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }
        }
        Kind::Continuation => {
            // 0x4 is the END_HEADERS flag bit.
            let is_end_headers = (head.flag() & 0x4) == 0x4;

            let mut partial = match partial_inout.take() {
                Some(partial) => partial,
                None => {
                    proto_err!(conn: "received unexpected CONTINUATION frame");
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            // The stream identifiers must match
            if partial.frame.stream_id() != head.stream_id() {
                proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            // Check for CONTINUATION flood: only frames that keep the block
            // open count toward the limit; the terminating frame resets it.
            if is_end_headers {
                partial.continuation_frames_count = 0;
            } else {
                let cnt = partial.continuation_frames_count + 1;
                if cnt > max_continuation_frames {
                    tracing::debug!("too_many_continuations, max = {}", max_continuation_frames);
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "too_many_continuations",
                    ));
                } else {
                    partial.continuation_frames_count = cnt;
                }
            }

            // Extend the buf
            if partial.buf.is_empty() {
                // Nothing left over: reuse this frame's payload buffer
                // directly instead of copying it.
                partial.buf = bytes.split_off(frame::HEADER_LEN);
            } else {
                if partial.frame.is_over_size() {
                    // If there were leftover bytes previously, they may be
                    // needed to continue decoding, even though we will
                    // be ignoring this frame. This is done to keep the HPACK
                    // decoder state up-to-date.
                    //
                    // Still, we need to be careful, because if a malicious
                    // attacker were to try to send a gigantic string, such
                    // that it fits over multiple header blocks, we could
                    // grow memory uncontrollably again, and that'd be a shame.
                    //
                    // Instead, we use a simple heuristic to determine if
                    // we should continue to ignore decoding, or to tell
                    // the attacker to go away.
                    if partial.buf.len() + bytes.len() > max_header_list_size {
                        proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
                        return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
                    }
                }
                partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
            }

            match partial
                .frame
                .load_hpack(&mut partial.buf, max_header_list_size, hpack)
            {
                Ok(_) => {}
                // More CONTINUATION frames are coming; keep the leftover bytes.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
                Err(frame::Error::MalformedMessage) => {
                    let id = head.stream_id();
                    proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                partial.frame.into()
            } else {
                // Put the block back so the next CONTINUATION can resume it.
                *partial_inout = Some(partial);
                return Ok(None);
            }
        }
        Kind::Unknown => {
            // Unknown frames are ignored
            return Ok(None);
        }
    };

    Ok(Some(frame))
}
372 
373 impl<T> Stream for FramedRead<T>
374 where
375     T: AsyncRead + Unpin,
376 {
377     type Item = Result<Frame, Error>;
378 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>379     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
380         let span = tracing::trace_span!("FramedRead::poll_next");
381         let _e = span.enter();
382         loop {
383             tracing::trace!("poll");
384             let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
385                 Some(Ok(bytes)) => bytes,
386                 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
387                 None => return Poll::Ready(None),
388             };
389 
390             tracing::trace!(read.bytes = bytes.len());
391             let Self {
392                 ref mut hpack,
393                 max_header_list_size,
394                 ref mut partial,
395                 max_continuation_frames,
396                 ..
397             } = *self;
398             if let Some(frame) = decode_frame(
399                 hpack,
400                 max_header_list_size,
401                 max_continuation_frames,
402                 partial,
403                 bytes,
404             )? {
405                 tracing::debug!(?frame, "received");
406                 return Poll::Ready(Some(Ok(frame)));
407             }
408         }
409     }
410 }
411 
map_err(err: io::Error) -> Error412 fn map_err(err: io::Error) -> Error {
413     if let io::ErrorKind::InvalidData = err.kind() {
414         if let Some(custom) = err.get_ref() {
415             if custom.is::<LengthDelimitedCodecError>() {
416                 return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
417             }
418         }
419     }
420     err.into()
421 }
422 
423 // ===== impl Continuable =====
424 
425 impl Continuable {
stream_id(&self) -> frame::StreamId426     fn stream_id(&self) -> frame::StreamId {
427         match *self {
428             Continuable::Headers(ref h) => h.stream_id(),
429             Continuable::PushPromise(ref p) => p.stream_id(),
430         }
431     }
432 
is_over_size(&self) -> bool433     fn is_over_size(&self) -> bool {
434         match *self {
435             Continuable::Headers(ref h) => h.is_over_size(),
436             Continuable::PushPromise(ref p) => p.is_over_size(),
437         }
438     }
439 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error>440     fn load_hpack(
441         &mut self,
442         src: &mut BytesMut,
443         max_header_list_size: usize,
444         decoder: &mut hpack::Decoder,
445     ) -> Result<(), frame::Error> {
446         match *self {
447             Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
448             Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
449         }
450     }
451 }
452 
453 impl<T> From<Continuable> for Frame<T> {
from(cont: Continuable) -> Self454     fn from(cont: Continuable) -> Self {
455         match cont {
456             Continuable::Headers(mut headers) => {
457                 headers.set_end_headers();
458                 headers.into()
459             }
460             Continuable::PushPromise(mut push) => {
461                 push.set_end_headers();
462                 push.into()
463             }
464         }
465     }
466 }
467