1 use crate::codec::decoder::Decoder;
2 use crate::codec::encoder::Encoder;
3 
4 use futures_core::Stream;
5 use tokio::io::{AsyncRead, AsyncWrite};
6 
7 use bytes::BytesMut;
8 use futures_sink::Sink;
9 use pin_project_lite::pin_project;
10 use std::borrow::{Borrow, BorrowMut};
11 use std::io;
12 use std::pin::Pin;
13 use std::task::{ready, Context, Poll};
14 
pin_project! {
    /// Generic engine behind this module's `Stream` and `Sink` impls: an I/O
    /// object, a codec, and the buffering `state` the codec decodes from and/or
    /// encodes into.
    #[derive(Debug)]
    pub(crate) struct FramedImpl<T, U, State> {
        // Structurally pinned: the I/O object is only ever polled through
        // `Pin<&mut T>` (see `poll_read_buf` / `poll_write_buf` / `poll_flush`).
        #[pin]
        pub(crate) inner: T,
        // Read and/or write buffering state (e.g. `ReadFrame`, `WriteFrame`
        // or `RWFrames`), accessed through `Borrow`/`BorrowMut`.
        pub(crate) state: State,
        // Codec used to turn bytes into frames and/or items into bytes.
        pub(crate) codec: U,
    }
}
24 
/// Default capacity (in bytes) of freshly allocated read and write buffers.
/// It is also the write backpressure boundary, and the capacity that
/// caller-supplied buffers are grown toward when adopted via `From<BytesMut>`.
const INITIAL_CAPACITY: usize = 8 * 1024;
26 
/// Read-side buffering state driven by `Stream::poll_next`.
///
/// The `eof`, `is_readable` and `has_errored` flags together encode the
/// decoding state machine (reading / framing / pausing / paused / errored)
/// described inside `poll_next`.
#[derive(Debug)]
pub(crate) struct ReadFrame {
    /// Set when the most recent read returned 0 bytes; cleared again as soon
    /// as a later read returns data.
    pub(crate) eof: bool,
    /// The buffer may hold decodable data, so `decode` / `decode_eof` should
    /// be tried before reading more from the I/O object.
    pub(crate) is_readable: bool,
    /// Bytes read from the I/O object and handed to the decoder.
    pub(crate) buffer: BytesMut,
    /// Set after the decoder or a read returned an error; the next poll then
    /// yields `None` and clears it.
    pub(crate) has_errored: bool,
}
34 
35 pub(crate) struct WriteFrame {
36     pub(crate) buffer: BytesMut,
37     pub(crate) backpressure_boundary: usize,
38 }
39 
40 #[derive(Default)]
41 pub(crate) struct RWFrames {
42     pub(crate) read: ReadFrame,
43     pub(crate) write: WriteFrame,
44 }
45 
46 impl Default for ReadFrame {
default() -> Self47     fn default() -> Self {
48         Self {
49             eof: false,
50             is_readable: false,
51             buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
52             has_errored: false,
53         }
54     }
55 }
56 
57 impl Default for WriteFrame {
default() -> Self58     fn default() -> Self {
59         Self {
60             buffer: BytesMut::with_capacity(INITIAL_CAPACITY),
61             backpressure_boundary: INITIAL_CAPACITY,
62         }
63     }
64 }
65 
66 impl From<BytesMut> for ReadFrame {
from(mut buffer: BytesMut) -> Self67     fn from(mut buffer: BytesMut) -> Self {
68         let size = buffer.capacity();
69         if size < INITIAL_CAPACITY {
70             buffer.reserve(INITIAL_CAPACITY - size);
71         }
72 
73         Self {
74             buffer,
75             is_readable: size > 0,
76             eof: false,
77             has_errored: false,
78         }
79     }
80 }
81 
82 impl From<BytesMut> for WriteFrame {
from(mut buffer: BytesMut) -> Self83     fn from(mut buffer: BytesMut) -> Self {
84         let size = buffer.capacity();
85         if size < INITIAL_CAPACITY {
86             buffer.reserve(INITIAL_CAPACITY - size);
87         }
88 
89         Self {
90             buffer,
91             backpressure_boundary: INITIAL_CAPACITY,
92         }
93     }
94 }
95 
96 impl Borrow<ReadFrame> for RWFrames {
borrow(&self) -> &ReadFrame97     fn borrow(&self) -> &ReadFrame {
98         &self.read
99     }
100 }
101 impl BorrowMut<ReadFrame> for RWFrames {
borrow_mut(&mut self) -> &mut ReadFrame102     fn borrow_mut(&mut self) -> &mut ReadFrame {
103         &mut self.read
104     }
105 }
106 impl Borrow<WriteFrame> for RWFrames {
borrow(&self) -> &WriteFrame107     fn borrow(&self) -> &WriteFrame {
108         &self.write
109     }
110 }
111 impl BorrowMut<WriteFrame> for RWFrames {
borrow_mut(&mut self) -> &mut WriteFrame112     fn borrow_mut(&mut self) -> &mut WriteFrame {
113         &mut self.write
114     }
115 }
/// Decodes frames out of the read state's buffer using the codec `U`, pulling
/// more bytes from the `AsyncRead` `T` whenever no frame can be produced yet.
///
/// `R` is any state exposing a [`ReadFrame`] through `BorrowMut` (e.g. a bare
/// `ReadFrame` or an `RWFrames`).
impl<T, U, R> Stream for FramedImpl<T, U, R>
where
    T: AsyncRead,
    U: Decoder,
    R: BorrowMut<ReadFrame>,
{
    type Item = Result<U::Item, U::Error>;

    /// Polls for the next decoded frame.
    ///
    /// - `Ready(Some(Ok(frame)))`: `decode` (or `decode_eof` after an EOF)
    ///   produced a frame.
    /// - `Ready(Some(Err(e)))`: the decoder or the underlying read failed; the
    ///   next poll after that returns `Ready(None)` once and clears the error.
    /// - `Ready(None)`: EOF was reached and `decode_eof` has no more frames.
    /// - `Pending`: no frame is available and the reader is not ready.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use crate::util::poll_read_buf;

        let mut pinned = self.project();
        let state: &mut ReadFrame = pinned.state.borrow_mut();
        // The following loop implements a state machine with each state corresponding
        // to a combination of the `is_readable` and `eof` flags. States persist across
        // loop entries and most state transitions occur with a return.
        //
        // The initial state is `reading`.
        //
        // | state   | eof   | is_readable | has_errored |
        // |---------|-------|-------------|-------------|
        // | reading | false | false       | false       |
        // | framing | false | true        | false       |
        // | pausing | true  | true        | false       |
        // | paused  | true  | false       | false       |
        // | errored | <any> | <any>       | true        |
        //                                                       `decode_eof` returns Err
        //                                          ┌────────────────────────────────────────────────────────┐
        //                   `decode_eof` returns   │                                                        │
        //                             `Ok(Some)`   │                                                        │
        //                                 ┌─────┐  │     `decode_eof` returns               After returning │
        //                Read 0 bytes     ├─────▼──┴┐    `Ok(None)`          ┌────────┐ ◄───┐ `None`    ┌───▼─────┐
        //               ┌────────────────►│ Pausing ├───────────────────────►│ Paused ├─┐   └───────────┤ Errored │
        //               │                 └─────────┘                        └─┬──▲───┘ │               └───▲───▲─┘
        // Pending read  │                                                      │  │     │                   │   │
        //     ┌──────┐  │            `decode` returns `Some`                   │  └─────┘                   │   │
        //     │      │  │                   ┌──────┐                           │  Pending                   │   │
        //     │ ┌────▼──┴─┐ Read n>0 bytes ┌┴──────▼─┐     read n>0 bytes      │  read                      │   │
        //     └─┤ Reading ├───────────────►│ Framing │◄────────────────────────┘                            │   │
        //       └──┬─▲────┘                └─────┬──┬┘                                                      │   │
        //          │ │                           │  │                 `decode` returns Err                  │   │
        //          │ └──`decode` returns `None`──┘  └───────────────────────────────────────────────────────┘   │
        //          │                             read returns Err                                               │
        //          └────────────────────────────────────────────────────────────────────────────────────────────┘
        //
        // Note: leaving `errored` only clears `is_readable` and `has_errored`; `eof`
        // is left untouched, so the machine lands in `paused` if `eof` was set when
        // the error occurred and in `reading` otherwise.
        loop {
            // Return `None` if we have encountered an error from the decoder or
            // from the underlying reader.
            // See: https://github.com/tokio-rs/tokio/issues/3976
            if state.has_errored {
                // preparing has_errored -> paused (or -> reading if `eof` is unset)
                trace!("Returning None and setting paused");
                state.is_readable = false;
                state.has_errored = false;
                return Poll::Ready(None);
            }

            // Repeatedly call `decode` or `decode_eof` while the buffer is "readable",
            // i.e. it _might_ contain data consumable as a frame or closing frame.
            // Both signal that there is no such data by returning `None`.
            //
            // If `decode` couldn't read a frame and the upstream source has returned eof,
            // `decode_eof` will attempt to decode the remaining bytes as closing frames.
            //
            // If the underlying AsyncRead is resumable, we may continue after an EOF,
            // but must finish emitting all of its associated `decode_eof` frames.
            // Furthermore, we don't want to emit any `decode_eof` frames on retried
            // reads after an EOF unless we've actually read more data.
            if state.is_readable {
                // pausing or framing
                if state.eof {
                    // pausing
                    let frame = pinned.codec.decode_eof(&mut state.buffer).map_err(|err| {
                        trace!("Got an error, going to errored state");
                        state.has_errored = true;
                        err
                    })?;
                    if frame.is_none() {
                        state.is_readable = false; // prepare pausing -> paused
                    }
                    // implicit pausing -> pausing or pausing -> paused
                    return Poll::Ready(frame.map(Ok));
                }

                // framing
                trace!("attempting to decode a frame");

                if let Some(frame) = pinned.codec.decode(&mut state.buffer).map_err(|op| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    op
                })? {
                    trace!("frame decoded from buffer");
                    // implicit framing -> framing
                    return Poll::Ready(Some(Ok(frame)));
                }

                // framing -> reading
                state.is_readable = false;
            }
            // reading or paused
            // If we can't build a frame yet, try to read more data and try again.
            // Make sure we've got room for at least one byte to read to ensure
            // that we don't get a spurious 0 that looks like EOF.
            state.buffer.reserve(1);
            // The match scrutinee below contains a `map_err` closure block, which
            // this clippy lint would otherwise flag.
            #[allow(clippy::blocks_in_conditions)]
            let bytect = match poll_read_buf(pinned.inner.as_mut(), cx, &mut state.buffer).map_err(
                |err| {
                    trace!("Got an error, going to errored state");
                    state.has_errored = true;
                    err
                },
            )? {
                Poll::Ready(ct) => ct,
                // implicit reading -> reading or implicit paused -> paused
                Poll::Pending => return Poll::Pending,
            };
            if bytect == 0 {
                if state.eof {
                    // We're already at an EOF, and since we've reached this path
                    // we're also not readable. This implies that we've already finished
                    // our `decode_eof` handling, so we can simply return `None`.
                    // implicit paused -> paused
                    return Poll::Ready(None);
                }
                // prepare reading -> pausing
                state.eof = true;
            } else {
                // prepare paused -> framing or noop reading -> framing
                state.eof = false;
            }

            // paused -> framing or reading -> framing or reading -> pausing
            state.is_readable = true;
        }
    }
}
251 
252 impl<T, I, U, W> Sink<I> for FramedImpl<T, U, W>
253 where
254     T: AsyncWrite,
255     U: Encoder<I>,
256     U::Error: From<io::Error>,
257     W: BorrowMut<WriteFrame>,
258 {
259     type Error = U::Error;
260 
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>261     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262         if self.state.borrow().buffer.len() >= self.state.borrow().backpressure_boundary {
263             self.as_mut().poll_flush(cx)
264         } else {
265             Poll::Ready(Ok(()))
266         }
267     }
268 
start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error>269     fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
270         let pinned = self.project();
271         pinned
272             .codec
273             .encode(item, &mut pinned.state.borrow_mut().buffer)?;
274         Ok(())
275     }
276 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>277     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
278         use crate::util::poll_write_buf;
279         trace!("flushing framed transport");
280         let mut pinned = self.project();
281 
282         while !pinned.state.borrow_mut().buffer.is_empty() {
283             let WriteFrame { buffer, .. } = pinned.state.borrow_mut();
284             trace!(remaining = buffer.len(), "writing;");
285 
286             let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?;
287 
288             if n == 0 {
289                 return Poll::Ready(Err(io::Error::new(
290                     io::ErrorKind::WriteZero,
291                     "failed to \
292                      write frame to transport",
293                 )
294                 .into()));
295             }
296         }
297 
298         // Try flushing the underlying IO
299         ready!(pinned.inner.poll_flush(cx))?;
300 
301         trace!("framed transport flushed");
302         Poll::Ready(Ok(()))
303     }
304 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>305     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
306         ready!(self.as_mut().poll_flush(cx))?;
307         ready!(self.project().inner.poll_shutdown(cx))?;
308 
309         Poll::Ready(Ok(()))
310     }
311 }
312