1 use crate::codec::decoder::Decoder; 2 use crate::codec::encoder::Encoder; 3 4 use futures_core::Stream; 5 use tokio::io::{AsyncRead, AsyncWrite}; 6 7 use bytes::BytesMut; 8 use futures_sink::Sink; 9 use pin_project_lite::pin_project; 10 use std::borrow::{Borrow, BorrowMut}; 11 use std::io; 12 use std::pin::Pin; 13 use std::task::{ready, Context, Poll}; 14 15 pin_project! { 16 #[derive(Debug)] 17 pub(crate) struct FramedImpl<T, U, State> { 18 #[pin] 19 pub(crate) inner: T, 20 pub(crate) state: State, 21 pub(crate) codec: U, 22 } 23 } 24 25 const INITIAL_CAPACITY: usize = 8 * 1024; 26 27 #[derive(Debug)] 28 pub(crate) struct ReadFrame { 29 pub(crate) eof: bool, 30 pub(crate) is_readable: bool, 31 pub(crate) buffer: BytesMut, 32 pub(crate) has_errored: bool, 33 } 34 35 pub(crate) struct WriteFrame { 36 pub(crate) buffer: BytesMut, 37 pub(crate) backpressure_boundary: usize, 38 } 39 40 #[derive(Default)] 41 pub(crate) struct RWFrames { 42 pub(crate) read: ReadFrame, 43 pub(crate) write: WriteFrame, 44 } 45 46 impl Default for ReadFrame { default() -> Self47 fn default() -> Self { 48 Self { 49 eof: false, 50 is_readable: false, 51 buffer: BytesMut::with_capacity(INITIAL_CAPACITY), 52 has_errored: false, 53 } 54 } 55 } 56 57 impl Default for WriteFrame { default() -> Self58 fn default() -> Self { 59 Self { 60 buffer: BytesMut::with_capacity(INITIAL_CAPACITY), 61 backpressure_boundary: INITIAL_CAPACITY, 62 } 63 } 64 } 65 66 impl From<BytesMut> for ReadFrame { from(mut buffer: BytesMut) -> Self67 fn from(mut buffer: BytesMut) -> Self { 68 let size = buffer.capacity(); 69 if size < INITIAL_CAPACITY { 70 buffer.reserve(INITIAL_CAPACITY - size); 71 } 72 73 Self { 74 buffer, 75 is_readable: size > 0, 76 eof: false, 77 has_errored: false, 78 } 79 } 80 } 81 82 impl From<BytesMut> for WriteFrame { from(mut buffer: BytesMut) -> Self83 fn from(mut buffer: BytesMut) -> Self { 84 let size = buffer.capacity(); 85 if size < INITIAL_CAPACITY { 86 buffer.reserve(INITIAL_CAPACITY - size); 87 } 88 89 Self { 90 buffer, 91 backpressure_boundary: INITIAL_CAPACITY, 92 } 93 } 94 } 95 96 impl Borrow<ReadFrame> for RWFrames { borrow(&self) -> &ReadFrame97 fn borrow(&self) -> &ReadFrame { 98 &self.read 99 } 100 } 101 impl BorrowMut<ReadFrame> for RWFrames { borrow_mut(&mut self) -> &mut ReadFrame102 fn borrow_mut(&mut self) -> &mut ReadFrame { 103 &mut self.read 104 } 105 } 106 impl Borrow<WriteFrame> for RWFrames { borrow(&self) -> &WriteFrame107 fn borrow(&self) -> &WriteFrame { 108 &self.write 109 } 110 } 111 impl BorrowMut<WriteFrame> for RWFrames { borrow_mut(&mut self) -> &mut WriteFrame112 fn borrow_mut(&mut self) -> &mut WriteFrame { 113 &mut self.write 114 } 115 } 116 impl<T, U, R> Stream for FramedImpl<T, U, R> 117 where 118 T: AsyncRead, 119 U: Decoder, 120 R: BorrowMut<ReadFrame>, 121 { 122 type Item = Result<U::Item, U::Error>; 123 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>124 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 125 use crate::util::poll_read_buf; 126 127 let mut pinned = self.project(); 128 let state: &mut ReadFrame = pinned.state.borrow_mut(); 129 // The following loops implements a state machine with each state corresponding 130 // to a combination of the `is_readable` and `eof` flags. States persist across 131 // loop entries and most state transitions occur with a return. 132 // 133 // The initial state is `reading`. 134 // 135 // | state | eof | is_readable | has_errored | 136 // |---------|-------|-------------|-------------| 137 // | reading | false | false | false | 138 // | framing | false | true | false | 139 // | pausing | true | true | false | 140 // | paused | true | false | false | 141 // | errored | <any> | <any> | true | 142 // `decode_eof` returns Err 143 // ┌────────────────────────────────────────────────────────┐ 144 // `decode_eof` returns │ │ 145 // `Ok(Some)` │ │ 146 // ┌─────┐ │ `decode_eof` returns After returning │ 147 // Read 0 bytes ├─────▼──┴┐ `Ok(None)` ┌────────┐ ◄───┐ `None` ┌───▼─────┐ 148 // ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐ └───────────┤ Errored │ 149 // │ └─────────┘ └─┬──▲───┘ │ └───▲───▲─┘ 150 // Pending read │ │ │ │ │ │ 151 // ┌──────┐ │ `decode` returns `Some` │ └─────┘ │ │ 152 // │ │ │ ┌──────┐ │ Pending │ │ 153 // │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐ read n>0 bytes │ read │ │ 154 // └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘ │ │ 155 // └──┬─▲────┘ └─────┬──┬┘ │ │ 156 // │ │ │ │ `decode` returns Err │ │ 157 // │ └───decode` returns `None`──┘ └───────────────────────────────────────────────────────┘ │ 158 // │ read returns Err │ 159 // └────────────────────────────────────────────────────────────────────────────────────────────┘ 160 loop { 161 // Return `None` if we have encountered an error from the underlying decoder 162 // See: https://github.com/tokio-rs/tokio/issues/3976 163 if state.has_errored { 164 // preparing has_errored -> paused 165 trace!("Returning None and setting paused"); 166 state.is_readable = false; 167 state.has_errored = false; 168 return Poll::Ready(None); 169 } 170 171 // Repeatedly call `decode` or `decode_eof` while the buffer is "readable", 172 // i.e. it _might_ contain data consumable as a frame or closing frame. 173 // Both signal that there is no such data by returning `None`. 174 // 175 // If `decode` couldn't read a frame and the upstream source has returned eof, 176 // `decode_eof` will attempt to decode the remaining bytes as closing frames. 177 // 178 // If the underlying AsyncRead is resumable, we may continue after an EOF, 179 // but must finish emitting all of it's associated `decode_eof` frames. 180 // Furthermore, we don't want to emit any `decode_eof` frames on retried 181 // reads after an EOF unless we've actually read more data. 182 if state.is_readable { 183 // pausing or framing 184 if state.eof { 185 // pausing 186 let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| { 187 trace!("Got an error, going to errored state"); 188 state.has_errored = true; 189 err 190 })?; 191 if frame.is_none() { 192 state.is_readable = false; // prepare pausing -> paused 193 } 194 // implicit pausing -> pausing or pausing -> paused 195 return Poll::Ready(frame.map(Ok)); 196 } 197 198 // framing 199 trace!("attempting to decode a frame"); 200 201 if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| { 202 trace!("Got an error, going to errored state"); 203 state.has_errored = true; 204 op 205 })? { 206 trace!("frame decoded from buffer"); 207 // implicit framing -> framing 208 return Poll::Ready(Some(Ok(frame))); 209 } 210 211 // framing -> reading 212 state.is_readable = false; 213 } 214 // reading or paused 215 // If we can't build a frame yet, try to read more data and try again. 216 // Make sure we've got room for at least one byte to read to ensure 217 // that we don't get a spurious 0 that looks like EOF. 218 state.buffer.reserve(1); 219 #[allow(clippy::blocks_in_conditions)] 220 let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err( 221 |err| { 222 trace!("Got an error, going to errored state"); 223 state.has_errored = true; 224 err 225 }, 226 )? { 227 Poll::Ready(ct) => ct, 228 // implicit reading -> reading or implicit paused -> paused 229 Poll::Pending => return Poll::Pending, 230 }; 231 if bytect == 0 { 232 if state.eof { 233 // We're already at an EOF, and since we've reached this path 234 // we're also not readable. This implies that we've already finished 235 // our `decode_eof` handling, so we can simply return `None`. 236 // implicit paused -> paused 237 return Poll::Ready(None); 238 } 239 // prepare reading -> paused 240 state.eof = true; 241 } else { 242 // prepare paused -> framing or noop reading -> framing 243 state.eof = false; 244 } 245 246 // paused -> framing or reading -> framing or reading -> pausing 247 state.is_readable = true; 248 } 249 } 250 } 251 252 impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W> 253 where 254 T: AsyncWrite, 255 U: Encoder<I>, 256 U::Error: From<io::Error>, 257 W: BorrowMut<WriteFrame>, 258 { 259 type Error = U::Error; 260 poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>261 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 262 if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary { 263 self.as_mut().poll_flush(cx) 264 } else { 265 Poll::Ready(Ok(())) 266 } 267 } 268 start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error>269 fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { 270 let pinned = self.project(); 271 pinned 272 .codec 273 .encode(item, &mut pinned.state.borrow_mut().buffer)?; 274 Ok(()) 275 } 276 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>277 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 278 use crate::util::poll_write_buf; 279 trace!("flushing framed transport"); 280 let mut pinned = self.project(); 281 282 while !pinned.state.borrow_mut().buffer.is_empty() { 283 let WriteFrame { buffer, .. } = pinned.state.borrow_mut(); 284 trace!(remaining = buffer.len(), "writing;"); 285 286 let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; 287 288 if n == 0 { 289 return Poll::Ready(Err(io::Error::new( 290 io::ErrorKind::WriteZero, 291 "failed to \ 292 write frame to transport", 293 ) 294 .into())); 295 } 296 } 297 298 // Try flushing the underlying IO 299 ready!(pinned.inner.poll_flush(cx))?; 300 301 trace!("framed transport flushed"); 302 Poll::Ready(Ok(())) 303 } 304 poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>305 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 306 ready!(self.as_mut().poll_flush(cx))?; 307 ready!(self.project().inner.poll_shutdown(cx))?; 308 309 Poll::Ready(Ok(())) 310 } 311 } 312