1 use crate::{Body, SizeHint}; 2 use bytes::Buf; 3 use http::HeaderMap; 4 use pin_project_lite::pin_project; 5 use std::error::Error; 6 use std::fmt; 7 use std::pin::Pin; 8 use std::task::{Context, Poll}; 9 10 pin_project! { 11 /// A length limited body. 12 /// 13 /// This body will return an error if more than the configured number 14 /// of bytes are returned on polling the wrapped body. 15 #[derive(Clone, Copy, Debug)] 16 pub struct Limited<B> { 17 remaining: usize, 18 #[pin] 19 inner: B, 20 } 21 } 22 23 impl<B> Limited<B> { 24 /// Create a new `Limited`. new(inner: B, limit: usize) -> Self25 pub fn new(inner: B, limit: usize) -> Self { 26 Self { 27 remaining: limit, 28 inner, 29 } 30 } 31 } 32 33 impl<B> Body for Limited<B> 34 where 35 B: Body, 36 B::Error: Into<Box<dyn Error + Send + Sync>>, 37 { 38 type Data = B::Data; 39 type Error = Box<dyn Error + Send + Sync>; 40 poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>41 fn poll_data( 42 self: Pin<&mut Self>, 43 cx: &mut Context<'_>, 44 ) -> Poll<Option<Result<Self::Data, Self::Error>>> { 45 let this = self.project(); 46 let res = match this.inner.poll_data(cx) { 47 Poll::Pending => return Poll::Pending, 48 Poll::Ready(None) => None, 49 Poll::Ready(Some(Ok(data))) => { 50 if data.remaining() > *this.remaining { 51 *this.remaining = 0; 52 Some(Err(LengthLimitError.into())) 53 } else { 54 *this.remaining -= data.remaining(); 55 Some(Ok(data)) 56 } 57 } 58 Poll::Ready(Some(Err(err))) => Some(Err(err.into())), 59 }; 60 61 Poll::Ready(res) 62 } 63 poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Self::Error>>64 fn poll_trailers( 65 self: Pin<&mut Self>, 66 cx: &mut Context<'_>, 67 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { 68 let this = self.project(); 69 let res = match this.inner.poll_trailers(cx) { 70 Poll::Pending => return Poll::Pending, 71 Poll::Ready(Ok(data)) => Ok(data), 72 Poll::Ready(Err(err)) => Err(err.into()), 73 }; 74 75 Poll::Ready(res) 76 } 77 is_end_stream(&self) -> bool78 fn is_end_stream(&self) -> bool { 79 self.inner.is_end_stream() 80 } 81 size_hint(&self) -> SizeHint82 fn size_hint(&self) -> SizeHint { 83 use std::convert::TryFrom; 84 match u64::try_from(self.remaining) { 85 Ok(n) => { 86 let mut hint = self.inner.size_hint(); 87 if hint.lower() >= n { 88 hint.set_exact(n) 89 } else if let Some(max) = hint.upper() { 90 hint.set_upper(n.min(max)) 91 } else { 92 hint.set_upper(n) 93 } 94 hint 95 } 96 Err(_) => self.inner.size_hint(), 97 } 98 } 99 } 100 101 /// An error returned when body length exceeds the configured limit. 102 #[derive(Debug)] 103 #[non_exhaustive] 104 pub struct LengthLimitError; 105 106 impl fmt::Display for LengthLimitError { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result107 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 108 f.write_str("length limit exceeded") 109 } 110 } 111 112 impl Error for LengthLimitError {} 113 114 #[cfg(test)] 115 mod tests { 116 use super::*; 117 use crate::Full; 118 use bytes::Bytes; 119 use std::convert::Infallible; 120 121 #[tokio::test] read_for_body_under_limit_returns_data()122 async fn read_for_body_under_limit_returns_data() { 123 const DATA: &[u8] = b"testing"; 124 let inner = Full::new(Bytes::from(DATA)); 125 let body = &mut Limited::new(inner, 8); 126 127 let mut hint = SizeHint::new(); 128 hint.set_upper(7); 129 assert_eq!(body.size_hint().upper(), hint.upper()); 130 131 let data = body.data().await.unwrap().unwrap(); 132 assert_eq!(data, DATA); 133 hint.set_upper(0); 134 assert_eq!(body.size_hint().upper(), hint.upper()); 135 136 assert!(matches!(body.data().await, None)); 137 } 138 139 #[tokio::test] read_for_body_over_limit_returns_error()140 async fn read_for_body_over_limit_returns_error() { 141 const DATA: &[u8] = b"testing a string that is too long"; 142 let inner = Full::new(Bytes::from(DATA)); 143 let body = &mut Limited::new(inner, 8); 144 145 let mut hint = SizeHint::new(); 146 hint.set_upper(8); 147 assert_eq!(body.size_hint().upper(), hint.upper()); 148 149 let error = body.data().await.unwrap().unwrap_err(); 150 assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); 151 } 152 153 struct Chunky(&'static [&'static [u8]]); 154 155 impl Body for Chunky { 156 type Data = &'static [u8]; 157 type Error = Infallible; 158 poll_data( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>159 fn poll_data( 160 self: Pin<&mut Self>, 161 _cx: &mut Context<'_>, 162 ) -> Poll<Option<Result<Self::Data, Self::Error>>> { 163 let mut this = self; 164 match this.0.split_first().map(|(&head, tail)| (Ok(head), tail)) { 165 Some((data, new_tail)) => { 166 this.0 = new_tail; 167 168 Poll::Ready(Some(data)) 169 } 170 None => Poll::Ready(None), 171 } 172 } 173 poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Self::Error>>174 fn poll_trailers( 175 self: Pin<&mut Self>, 176 _cx: &mut Context<'_>, 177 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { 178 Poll::Ready(Ok(Some(HeaderMap::new()))) 179 } 180 } 181 182 #[tokio::test] read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk( )183 async fn read_for_chunked_body_around_limit_returns_first_chunk_but_returns_error_on_over_limit_chunk( 184 ) { 185 const DATA: &[&[u8]] = &[b"testing ", b"a string that is too long"]; 186 let inner = Chunky(DATA); 187 let body = &mut Limited::new(inner, 8); 188 189 let mut hint = SizeHint::new(); 190 hint.set_upper(8); 191 assert_eq!(body.size_hint().upper(), hint.upper()); 192 193 let data = body.data().await.unwrap().unwrap(); 194 assert_eq!(data, DATA[0]); 195 hint.set_upper(0); 196 assert_eq!(body.size_hint().upper(), hint.upper()); 197 198 let error = body.data().await.unwrap().unwrap_err(); 199 assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); 200 } 201 202 #[tokio::test] read_for_chunked_body_over_limit_on_first_chunk_returns_error()203 async fn read_for_chunked_body_over_limit_on_first_chunk_returns_error() { 204 const DATA: &[&[u8]] = &[b"testing a string", b" that is too long"]; 205 let inner = Chunky(DATA); 206 let body = &mut Limited::new(inner, 8); 207 208 let mut hint = SizeHint::new(); 209 hint.set_upper(8); 210 assert_eq!(body.size_hint().upper(), hint.upper()); 211 212 let error = body.data().await.unwrap().unwrap_err(); 213 assert!(matches!(error.downcast_ref(), Some(LengthLimitError))); 214 } 215 216 #[tokio::test] read_for_chunked_body_under_limit_is_okay()217 async fn read_for_chunked_body_under_limit_is_okay() { 218 const DATA: &[&[u8]] = &[b"test", b"ing!"]; 219 let inner = Chunky(DATA); 220 let body = &mut Limited::new(inner, 8); 221 222 let mut hint = SizeHint::new(); 223 hint.set_upper(8); 224 assert_eq!(body.size_hint().upper(), hint.upper()); 225 226 let data = body.data().await.unwrap().unwrap(); 227 assert_eq!(data, DATA[0]); 228 hint.set_upper(4); 229 assert_eq!(body.size_hint().upper(), hint.upper()); 230 231 let data = body.data().await.unwrap().unwrap(); 232 assert_eq!(data, DATA[1]); 233 hint.set_upper(0); 234 assert_eq!(body.size_hint().upper(), hint.upper()); 235 236 assert!(matches!(body.data().await, None)); 237 } 238 239 #[tokio::test] read_for_trailers_propagates_inner_trailers()240 async fn read_for_trailers_propagates_inner_trailers() { 241 const DATA: &[&[u8]] = &[b"test", b"ing!"]; 242 let inner = Chunky(DATA); 243 let body = &mut Limited::new(inner, 8); 244 let trailers = body.trailers().await.unwrap(); 245 assert_eq!(trailers, Some(HeaderMap::new())) 246 } 247 248 #[derive(Debug)] 249 enum ErrorBodyError { 250 Data, 251 Trailers, 252 } 253 254 impl fmt::Display for ErrorBodyError { fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result255 fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { 256 Ok(()) 257 } 258 } 259 260 impl Error for ErrorBodyError {} 261 262 struct ErrorBody; 263 264 impl Body for ErrorBody { 265 type Data = &'static [u8]; 266 type Error = ErrorBodyError; 267 poll_data( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>268 fn poll_data( 269 self: Pin<&mut Self>, 270 _cx: &mut Context<'_>, 271 ) -> Poll<Option<Result<Self::Data, Self::Error>>> { 272 Poll::Ready(Some(Err(ErrorBodyError::Data))) 273 } 274 poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<HeaderMap>, Self::Error>>275 fn poll_trailers( 276 self: Pin<&mut Self>, 277 _cx: &mut Context<'_>, 278 ) -> Poll<Result<Option<HeaderMap>, Self::Error>> { 279 Poll::Ready(Err(ErrorBodyError::Trailers)) 280 } 281 } 282 283 #[tokio::test] read_for_body_returning_error_propagates_error()284 async fn read_for_body_returning_error_propagates_error() { 285 let body = &mut Limited::new(ErrorBody, 8); 286 let error = body.data().await.unwrap().unwrap_err(); 287 assert!(matches!(error.downcast_ref(), Some(ErrorBodyError::Data))); 288 } 289 290 #[tokio::test] trailers_for_body_returning_error_propagates_error()291 async fn trailers_for_body_returning_error_propagates_error() { 292 let body = &mut Limited::new(ErrorBody, 8); 293 let error = body.trailers().await.unwrap_err(); 294 assert!(matches!( 295 error.downcast_ref(), 296 Some(ErrorBodyError::Trailers) 297 )); 298 } 299 } 300