1 use crate::{Body, SizeHint};
2 use bytes::Buf;
3 use http::HeaderMap;
4 use pin_project_lite::pin_project;
5 use std::error::Error;
6 use std::fmt;
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 
10 pin_project! {
11     /// A length limited body.
12     ///
13     /// This body will return an error if more than the configured number
14     /// of bytes are returned on polling the wrapped body.
15     #[derive(Clone, Copy, Debug)]
16     pub struct Limited<B> {
17         remaining: usize,
18         #[pin]
19         inner: B,
20     }
21 }
22 
23 impl<B> Limited<B> {
24     /// Create a new `Limited`.
new(inner: B, limit: usize) -> Self25     pub fn new(inner: B, limit: usize) -> Self {
26         Self {
27             remaining: limit,
28             inner,
29         }
30     }
31 }
32 
33 impl<B> Body for Limited<B>
34 where
35     B: Body,
36     B::Error: Into<Box<dyn Error + Send + Sync>>,
37 {
38     type Data = B::Data;
39     type Error = Box<dyn Error + Send + Sync>;
40 
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>41     fn poll_data(
42         self: Pin<&mut Self>,
43         cx: &mut Context<'_>,
44     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
45         let this = self.project();
46         let res = match this.inner.poll_data(cx) {
47             Poll::Pending => return Poll::Pending,
48             Poll::Ready(None) => None,
49             Poll::Ready(Some(Ok(data))) => {
50                 if data.remaining() > *this.remaining {
51                     *this.remaining = 0;
52                     Some(Err(LengthLimitError.into()))
53                 } else {
54                     *this.remaining -= data.remaining();
55                     Some(Ok(data))
56                 }
57             }
58             Poll::Ready(Some(Err(err))) => Some(Err(err.into())),
59         };
60 
61         Poll::Ready(res)
62     }
63 
poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Self::Error>>64     fn poll_trailers(
65         self: Pin<&mut Self>,
66         cx: &mut Context<'_>,
67     ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
68         let this = self.project();
69         let res = match this.inner.poll_trailers(cx) {
70             Poll::Pending => return Poll::Pending,
71             Poll::Ready(Ok(data)) => Ok(data),
72             Poll::Ready(Err(err)) => Err(err.into()),
73         };
74 
75         Poll::Ready(res)
76     }
77 
is_end_stream(&self) -> bool78     fn is_end_stream(&self) -> bool {
79         self.inner.is_end_stream()
80     }
81 
size_hint(&self) -> SizeHint82     fn size_hint(&self) -> SizeHint {
83         use std::convert::TryFrom;
84         match u64::try_from(self.remaining) {
85             Ok(n) => {
86                 let mut hint = self.inner.size_hint();
87                 if hint.lower() >= n {
88                     hint.set_exact(n)
89                 } else if let Some(max) = hint.upper() {
90                     hint.set_upper(n.min(max))
91                 } else {
92                     hint.set_upper(n)
93                 }
94                 hint
95             }
96             Err(_) => self.inner.size_hint(),
97         }
98     }
99 }
100 
/// The error yielded by [`Limited`] when the wrapped body produces more data
/// bytes than the configured limit allows.
#[non_exhaustive]
#[derive(Debug)]
pub struct LengthLimitError;

impl fmt::Display for LengthLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "length limit exceeded")
    }
}

impl Error for LengthLimitError {}
113 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Full;
    use bytes::Bytes;
    use std::convert::Infallible;

    #[tokio::test]
    async fn read_for_body_under_limit_returns_data() {
        const DATA: &[u8] = b"testing";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), 8);

        // A 7-byte body under an 8-byte limit keeps its own upper bound.
        assert_eq!(body.size_hint().upper(), Some(7));

        let chunk = body.data().await.unwrap().unwrap();
        assert_eq!(chunk, DATA);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_body_over_limit_returns_error() {
        const DATA: &[u8] = b"testing a string that is too long";
        let mut body = Limited::new(Full::new(Bytes::from(DATA)), 8);

        // The limit caps the advertised size of an oversized body.
        assert_eq!(body.size_hint().upper(), Some(8));

        let err = body.data().await.unwrap().unwrap_err();
        assert!(err.downcast_ref::<LengthLimitError>().is_some());
    }

    /// Test body that yields a fixed sequence of byte-slice chunks, followed
    /// by an empty trailer map.
    struct Chunky(&'static [&'static [u8]]);

    impl Body for Chunky {
        type Data = &'static [u8];
        type Error = Infallible;

        fn poll_data(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            let chunks: &'static [&'static [u8]] = self.0;
            let next = match chunks.split_first() {
                Some((&head, rest)) => {
                    self.0 = rest;
                    Some(Ok(head))
                }
                None => None,
            };
            Poll::Ready(next)
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Ok(Some(HeaderMap::new())))
        }
    }

    #[tokio::test]
    async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk(
    ) {
        const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"];
        let mut body = Limited::new(Chunky(DATA), 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let first = body.data().await.unwrap().unwrap();
        assert_eq!(first, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(0));

        let err = body.data().await.unwrap().unwrap_err();
        assert!(err.downcast_ref::<LengthLimitError>().is_some());
    }

    #[tokio::test]
    async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() {
        const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"];
        let mut body = Limited::new(Chunky(DATA), 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let err = body.data().await.unwrap().unwrap_err();
        assert!(err.downcast_ref::<LengthLimitError>().is_some());
    }

    #[tokio::test]
    async fn read_for_chunked_body_under_limit_is_okay() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let mut body = Limited::new(Chunky(DATA), 8);

        assert_eq!(body.size_hint().upper(), Some(8));

        let first = body.data().await.unwrap().unwrap();
        assert_eq!(first, DATA[0]);
        assert_eq!(body.size_hint().upper(), Some(4));

        let second = body.data().await.unwrap().unwrap();
        assert_eq!(second, DATA[1]);
        assert_eq!(body.size_hint().upper(), Some(0));

        assert!(body.data().await.is_none());
    }

    #[tokio::test]
    async fn read_for_trailers_propagates_inner_trailers() {
        const DATA: &[&[u8]] = &[b"test", b"ing!"];
        let mut body = Limited::new(Chunky(DATA), 8);

        let trailers = body.trailers().await.unwrap();
        assert_eq!(trailers, Some(HeaderMap::new()));
    }

    /// Error type for `ErrorBody`. Each variant names the poll method that
    /// produced it.
    #[derive(Debug)]
    enum ErrorBodyError {
        Data,
        Trailers,
    }

    impl fmt::Display for ErrorBodyError {
        fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
            Ok(())
        }
    }

    impl Error for ErrorBodyError {}

    /// Test body whose data and trailer polls always fail immediately.
    struct ErrorBody;

    impl Body for ErrorBody {
        type Data = &'static [u8];
        type Error = ErrorBodyError;

        fn poll_data(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
            Poll::Ready(Some(Err(ErrorBodyError::Data)))
        }

        fn poll_trailers(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
            Poll::Ready(Err(ErrorBodyError::Trailers))
        }
    }

    #[tokio::test]
    async fn read_for_body_returning_error_propagates_error() {
        let mut body = Limited::new(ErrorBody, 8);

        let err = body.data().await.unwrap().unwrap_err();
        assert!(matches!(
            err.downcast_ref::<ErrorBodyError>(),
            Some(ErrorBodyError::Data)
        ));
    }

    #[tokio::test]
    async fn trailers_for_body_returning_error_propagates_error() {
        let mut body = Limited::new(ErrorBody, 8);

        let err = body.trailers().await.unwrap_err();
        assert!(matches!(
            err.downcast_ref::<ErrorBodyError>(),
            Some(ErrorBodyError::Trailers)
        ));
    }
}
300