1 use crate::codec::encoder::Encoder;
2 use crate::codec::framed_impl::{FramedImpl, WriteFrame};
3 
4 use futures_core::Stream;
5 use tokio::io::AsyncWrite;
6 
7 use bytes::BytesMut;
8 use futures_sink::Sink;
9 use pin_project_lite::pin_project;
10 use std::fmt;
11 use std::io;
12 use std::pin::Pin;
13 use std::task::{Context, Poll};
14 
15 pin_project! {
16     /// A [`Sink`] of frames encoded to an `AsyncWrite`.
17     ///
18     /// For examples of how to use `FramedWrite` with a codec, see the
19     /// examples on the [`codec`] module.
20     ///
21     /// # Cancellation safety
22     ///
23     /// * [`futures_util::sink::SinkExt::send`]: if send is used as the event in a
24     /// `tokio::select!` statement and some other branch completes first, then it is
25     /// guaranteed that the message was not sent, but the message itself is lost.
26     ///
27     /// [`Sink`]: futures_sink::Sink
28     /// [`codec`]: crate::codec
29     /// [`futures_util::sink::SinkExt::send`]: futures_util::sink::SinkExt::send
30     pub struct FramedWrite<T, E> {
31         #[pin]
32         inner: FramedImpl<T, E, WriteFrame>,
33     }
34 }
35 
36 impl<T, E> FramedWrite<T, E>
37 where
38     T: AsyncWrite,
39 {
40     /// Creates a new `FramedWrite` with the given `encoder`.
new(inner: T, encoder: E) -> FramedWrite<T, E>41     pub fn new(inner: T, encoder: E) -> FramedWrite<T, E> {
42         FramedWrite {
43             inner: FramedImpl {
44                 inner,
45                 codec: encoder,
46                 state: WriteFrame::default(),
47             },
48         }
49     }
50 }
51 
52 impl<T, E> FramedWrite<T, E> {
53     /// Returns a reference to the underlying I/O stream wrapped by
54     /// `FramedWrite`.
55     ///
56     /// Note that care should be taken to not tamper with the underlying stream
57     /// of data coming in as it may corrupt the stream of frames otherwise
58     /// being worked with.
get_ref(&self) -> &T59     pub fn get_ref(&self) -> &T {
60         &self.inner.inner
61     }
62 
63     /// Returns a mutable reference to the underlying I/O stream wrapped by
64     /// `FramedWrite`.
65     ///
66     /// Note that care should be taken to not tamper with the underlying stream
67     /// of data coming in as it may corrupt the stream of frames otherwise
68     /// being worked with.
get_mut(&mut self) -> &mut T69     pub fn get_mut(&mut self) -> &mut T {
70         &mut self.inner.inner
71     }
72 
73     /// Returns a pinned mutable reference to the underlying I/O stream wrapped by
74     /// `FramedWrite`.
75     ///
76     /// Note that care should be taken to not tamper with the underlying stream
77     /// of data coming in as it may corrupt the stream of frames otherwise
78     /// being worked with.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T>79     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
80         self.project().inner.project().inner
81     }
82 
83     /// Consumes the `FramedWrite`, returning its underlying I/O stream.
84     ///
85     /// Note that care should be taken to not tamper with the underlying stream
86     /// of data coming in as it may corrupt the stream of frames otherwise
87     /// being worked with.
into_inner(self) -> T88     pub fn into_inner(self) -> T {
89         self.inner.inner
90     }
91 
92     /// Returns a reference to the underlying encoder.
encoder(&self) -> &E93     pub fn encoder(&self) -> &E {
94         &self.inner.codec
95     }
96 
97     /// Returns a mutable reference to the underlying encoder.
encoder_mut(&mut self) -> &mut E98     pub fn encoder_mut(&mut self) -> &mut E {
99         &mut self.inner.codec
100     }
101 
102     /// Maps the encoder `E` to `C`, preserving the write buffer
103     /// wrapped by `Framed`.
map_encoder<C, F>(self, map: F) -> FramedWrite<T, C> where F: FnOnce(E) -> C,104     pub fn map_encoder<C, F>(self, map: F) -> FramedWrite<T, C>
105     where
106         F: FnOnce(E) -> C,
107     {
108         // This could be potentially simplified once rust-lang/rust#86555 hits stable
109         let FramedImpl {
110             inner,
111             state,
112             codec,
113         } = self.inner;
114         FramedWrite {
115             inner: FramedImpl {
116                 inner,
117                 state,
118                 codec: map(codec),
119             },
120         }
121     }
122 
123     /// Returns a mutable reference to the underlying encoder.
encoder_pin_mut(self: Pin<&mut Self>) -> &mut E124     pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E {
125         self.project().inner.project().codec
126     }
127 
128     /// Returns a reference to the write buffer.
write_buffer(&self) -> &BytesMut129     pub fn write_buffer(&self) -> &BytesMut {
130         &self.inner.state.buffer
131     }
132 
133     /// Returns a mutable reference to the write buffer.
write_buffer_mut(&mut self) -> &mut BytesMut134     pub fn write_buffer_mut(&mut self) -> &mut BytesMut {
135         &mut self.inner.state.buffer
136     }
137 
138     /// Returns backpressure boundary
backpressure_boundary(&self) -> usize139     pub fn backpressure_boundary(&self) -> usize {
140         self.inner.state.backpressure_boundary
141     }
142 
143     /// Updates backpressure boundary
set_backpressure_boundary(&mut self, boundary: usize)144     pub fn set_backpressure_boundary(&mut self, boundary: usize) {
145         self.inner.state.backpressure_boundary = boundary;
146     }
147 }
148 
149 // This impl just defers to the underlying FramedImpl
150 impl<T, I, E> Sink<I> for FramedWrite<T, E>
151 where
152     T: AsyncWrite,
153     E: Encoder<I>,
154     E::Error: From<io::Error>,
155 {
156     type Error = E::Error;
157 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>158     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159         self.project().inner.poll_ready(cx)
160     }
161 
start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error>162     fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
163         self.project().inner.start_send(item)
164     }
165 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>166     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167         self.project().inner.poll_flush(cx)
168     }
169 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>170     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
171         self.project().inner.poll_close(cx)
172     }
173 }
174 
175 // This impl just defers to the underlying T: Stream
176 impl<T, D> Stream for FramedWrite<T, D>
177 where
178     T: Stream,
179 {
180     type Item = T::Item;
181 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>182     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183         self.project().inner.project().inner.poll_next(cx)
184     }
185 }
186 
187 impl<T, U> fmt::Debug for FramedWrite<T, U>
188 where
189     T: fmt::Debug,
190     U: fmt::Debug,
191 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result192     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193         f.debug_struct("FramedWrite")
194             .field("inner", &self.get_ref())
195             .field("encoder", &self.encoder())
196             .field("buffer", &self.inner.state.buffer)
197             .finish()
198     }
199 }
200