1 mod error;
2 mod framed_read;
3 mod framed_write;
4 
5 pub use self::error::{SendError, UserError};
6 
7 use self::framed_read::FramedRead;
8 use self::framed_write::FramedWrite;
9 
10 use crate::frame::{self, Data, Frame};
11 use crate::proto::Error;
12 
13 use bytes::Buf;
14 use futures_core::Stream;
15 use futures_sink::Sink;
16 use std::pin::Pin;
17 use std::task::{Context, Poll};
18 use tokio::io::{AsyncRead, AsyncWrite};
19 use tokio_util::codec::length_delimited;
20 
21 use std::io;
22 
/// Frame codec over the I/O object `T`.
///
/// `B` is the buffer type carried by outgoing frames (`Frame<B>` passed to
/// `buffer`, `Data<B>` returned by `take_last_data_frame`).
#[derive(Debug)]
pub struct Codec<T, B> {
    // Read side wrapping the write side: `FramedRead` decodes frames from a
    // length-delimited reader built over `FramedWrite`, and `FramedWrite`
    // owns the underlying I/O object `T`.
    inner: FramedRead<FramedWrite<T, B>>,
}
27 
28 impl<T, B> Codec<T, B>
29 where
30     T: AsyncRead + AsyncWrite + Unpin,
31     B: Buf,
32 {
33     /// Returns a new `Codec` with the default max frame size
34     #[inline]
new(io: T) -> Self35     pub fn new(io: T) -> Self {
36         Self::with_max_recv_frame_size(io, frame::DEFAULT_MAX_FRAME_SIZE as usize)
37     }
38 
39     /// Returns a new `Codec` with the given maximum frame size
with_max_recv_frame_size(io: T, max_frame_size: usize) -> Self40     pub fn with_max_recv_frame_size(io: T, max_frame_size: usize) -> Self {
41         // Wrap with writer
42         let framed_write = FramedWrite::new(io);
43 
44         // Delimit the frames
45         let delimited = length_delimited::Builder::new()
46             .big_endian()
47             .length_field_length(3)
48             .length_adjustment(9)
49             .num_skip(0) // Don't skip the header
50             .new_read(framed_write);
51 
52         let mut inner = FramedRead::new(delimited);
53 
54         // Use FramedRead's method since it checks the value is within range.
55         inner.set_max_frame_size(max_frame_size);
56 
57         Codec { inner }
58     }
59 }
60 
61 impl<T, B> Codec<T, B> {
62     /// Updates the max received frame size.
63     ///
64     /// The change takes effect the next time a frame is decoded. In other
65     /// words, if a frame is currently in process of being decoded with a frame
66     /// size greater than `val` but less than the max frame size in effect
67     /// before calling this function, then the frame will be allowed.
68     #[inline]
set_max_recv_frame_size(&mut self, val: usize)69     pub fn set_max_recv_frame_size(&mut self, val: usize) {
70         self.inner.set_max_frame_size(val)
71     }
72 
73     /// Returns the current max received frame size setting.
74     ///
75     /// This is the largest size this codec will accept from the wire. Larger
76     /// frames will be rejected.
77     #[cfg(feature = "unstable")]
78     #[inline]
max_recv_frame_size(&self) -> usize79     pub fn max_recv_frame_size(&self) -> usize {
80         self.inner.max_frame_size()
81     }
82 
83     /// Returns the max frame size that can be sent to the peer.
max_send_frame_size(&self) -> usize84     pub fn max_send_frame_size(&self) -> usize {
85         self.inner.get_ref().max_frame_size()
86     }
87 
88     /// Set the peer's max frame size.
set_max_send_frame_size(&mut self, val: usize)89     pub fn set_max_send_frame_size(&mut self, val: usize) {
90         self.framed_write().set_max_frame_size(val)
91     }
92 
93     /// Set the peer's header table size size.
set_send_header_table_size(&mut self, val: usize)94     pub fn set_send_header_table_size(&mut self, val: usize) {
95         self.framed_write().set_header_table_size(val)
96     }
97 
98     /// Set the decoder header table size size.
set_recv_header_table_size(&mut self, val: usize)99     pub fn set_recv_header_table_size(&mut self, val: usize) {
100         self.inner.set_header_table_size(val)
101     }
102 
103     /// Set the max header list size that can be received.
set_max_recv_header_list_size(&mut self, val: usize)104     pub fn set_max_recv_header_list_size(&mut self, val: usize) {
105         self.inner.set_max_header_list_size(val);
106     }
107 
108     /// Get a reference to the inner stream.
109     #[cfg(feature = "unstable")]
get_ref(&self) -> &T110     pub fn get_ref(&self) -> &T {
111         self.inner.get_ref().get_ref()
112     }
113 
114     /// Get a mutable reference to the inner stream.
get_mut(&mut self) -> &mut T115     pub fn get_mut(&mut self) -> &mut T {
116         self.inner.get_mut().get_mut()
117     }
118 
119     /// Takes the data payload value that was fully written to the socket
take_last_data_frame(&mut self) -> Option<Data<B>>120     pub(crate) fn take_last_data_frame(&mut self) -> Option<Data<B>> {
121         self.framed_write().take_last_data_frame()
122     }
123 
framed_write(&mut self) -> &mut FramedWrite<T, B>124     fn framed_write(&mut self) -> &mut FramedWrite<T, B> {
125         self.inner.get_mut()
126     }
127 }
128 
129 impl<T, B> Codec<T, B>
130 where
131     T: AsyncWrite + Unpin,
132     B: Buf,
133 {
134     /// Returns `Ready` when the codec can buffer a frame
poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>135     pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
136         self.framed_write().poll_ready(cx)
137     }
138 
139     /// Buffer a frame.
140     ///
141     /// `poll_ready` must be called first to ensure that a frame may be
142     /// accepted.
143     ///
144     /// TODO: Rename this to avoid conflicts with Sink::buffer
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>145     pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
146         self.framed_write().buffer(item)
147     }
148 
149     /// Flush buffered data to the wire
flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>>150     pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
151         self.framed_write().flush(cx)
152     }
153 
154     /// Shutdown the send half
shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>>155     pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
156         self.framed_write().shutdown(cx)
157     }
158 }
159 
160 impl<T, B> Stream for Codec<T, B>
161 where
162     T: AsyncRead + Unpin,
163 {
164     type Item = Result<Frame, Error>;
165 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>166     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167         Pin::new(&mut self.inner).poll_next(cx)
168     }
169 }
170 
171 impl<T, B> Sink<Frame<B>> for Codec<T, B>
172 where
173     T: AsyncWrite + Unpin,
174     B: Buf,
175 {
176     type Error = SendError;
177 
start_send(mut self: Pin<&mut Self>, item: Frame<B>) -> Result<(), Self::Error>178     fn start_send(mut self: Pin<&mut Self>, item: Frame<B>) -> Result<(), Self::Error> {
179         Codec::buffer(&mut self, item)?;
180         Ok(())
181     }
182     /// Returns `Ready` when the codec can buffer a frame
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>183     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184         self.framed_write().poll_ready(cx).map_err(Into::into)
185     }
186 
187     /// Flush buffered data to the wire
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>188     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189         self.framed_write().flush(cx).map_err(Into::into)
190     }
191 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>192     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
193         ready!(self.shutdown(cx))?;
194         Poll::Ready(Ok(()))
195     }
196 }
197 
198 // TODO: remove (or improve) this
199 impl<T> From<T> for Codec<T, bytes::Bytes>
200 where
201     T: AsyncRead + AsyncWrite + Unpin,
202 {
from(src: T) -> Self203     fn from(src: T) -> Self {
204         Self::new(src)
205     }
206 }
207