1 use super::{
2     store, Buffer, Codec, Config, Counts, Frame, Prioritize, Prioritized, Store, Stream, StreamId,
3     StreamIdOverflow, WindowSize,
4 };
5 use crate::codec::UserError;
6 use crate::frame::{self, Reason};
7 use crate::proto::{self, Error, Initiator};
8 
9 use bytes::Buf;
10 use tokio::io::AsyncWrite;
11 
12 use std::cmp::Ordering;
13 use std::io;
14 use std::task::{Context, Poll, Waker};
15 
/// Manages state transitions related to outbound frames.
///
/// Owns the local stream-id allocator, the peer-advertised send settings
/// (initial window size, server push, extended CONNECT) and the
/// prioritization layer that queues outbound frames.
#[derive(Debug)]
pub(super) struct Send {
    /// Stream identifier to use for next initialized stream.
    ///
    /// Becomes `Err(StreamIdOverflow)` once the id space is exhausted; every
    /// later allocation then fails with `UserError::OverflowedStreamId`.
    next_stream_id: Result<StreamId, StreamIdOverflow>,

    /// Any streams with a higher ID are ignored.
    ///
    /// This starts as MAX, but is lowered when a GOAWAY is received
    /// (`recv_go_away`), and a later GOAWAY may never raise it again.
    ///
    /// > After sending a GOAWAY frame, the sender can discard frames for
    /// > streams initiated by the receiver with identifiers higher than
    /// > the identified last stream.
    max_stream_id: StreamId,

    /// Initial send-window size for streams, as advertised by the remote peer.
    ///
    /// Seeded from `Config::remote_init_window_sz` and replaced whenever the
    /// peer's SETTINGS carry an initial window size.
    init_window_sz: WindowSize,

    /// Prioritization layer
    prioritize: Prioritize,

    /// Whether the peer permits server push.
    ///
    /// Starts `true` and is updated from the peer's SETTINGS; when `false`,
    /// sending a PUSH_PROMISE fails with `UserError::PeerDisabledServerPush`.
    is_push_enabled: bool,

    /// If extended connect protocol is enabled.
    ///
    /// Starts `false` and is updated from the peer's SETTINGS.
    is_extended_connect_protocol_enabled: bool,
}
42 
/// A value to detect which public API has called `poll_reset`.
///
/// `Send::poll_reset` forwards it to `State::ensure_reason`, which decides
/// how to report a stream's state to that caller.
/// NOTE(review): the exact per-variant handling lives in the stream state
/// module and is not visible here — confirm there.
#[derive(Debug)]
pub(crate) enum PollReset {
    /// Presumably: the caller is still waiting for the peer's headers
    /// (e.g. a pending response) — TODO confirm.
    AwaitingHeaders,
    /// Presumably: the caller is an active sending-stream handle — TODO confirm.
    Streaming,
}
49 
50 impl Send {
51     /// Create a new `Send`
new(config: &Config) -> Self52     pub fn new(config: &Config) -> Self {
53         Send {
54             init_window_sz: config.remote_init_window_sz,
55             max_stream_id: StreamId::MAX,
56             next_stream_id: Ok(config.local_next_stream_id),
57             prioritize: Prioritize::new(config),
58             is_push_enabled: true,
59             is_extended_connect_protocol_enabled: false,
60         }
61     }
62 
63     /// Returns the initial send window size
init_window_sz(&self) -> WindowSize64     pub fn init_window_sz(&self) -> WindowSize {
65         self.init_window_sz
66     }
67 
open(&mut self) -> Result<StreamId, UserError>68     pub fn open(&mut self) -> Result<StreamId, UserError> {
69         let stream_id = self.ensure_next_stream_id()?;
70         self.next_stream_id = stream_id.next_id();
71         Ok(stream_id)
72     }
73 
reserve_local(&mut self) -> Result<StreamId, UserError>74     pub fn reserve_local(&mut self) -> Result<StreamId, UserError> {
75         let stream_id = self.ensure_next_stream_id()?;
76         self.next_stream_id = stream_id.next_id();
77         Ok(stream_id)
78     }
79 
check_headers(fields: &http::HeaderMap) -> Result<(), UserError>80     fn check_headers(fields: &http::HeaderMap) -> Result<(), UserError> {
81         // 8.1.2.2. Connection-Specific Header Fields
82         if fields.contains_key(http::header::CONNECTION)
83             || fields.contains_key(http::header::TRANSFER_ENCODING)
84             || fields.contains_key(http::header::UPGRADE)
85             || fields.contains_key("keep-alive")
86             || fields.contains_key("proxy-connection")
87         {
88             tracing::debug!("illegal connection-specific headers found");
89             return Err(UserError::MalformedHeaders);
90         } else if let Some(te) = fields.get(http::header::TE) {
91             if te != "trailers" {
92                 tracing::debug!("illegal connection-specific headers found");
93                 return Err(UserError::MalformedHeaders);
94             }
95         }
96         Ok(())
97     }
98 
send_push_promise<B>( &mut self, frame: frame::PushPromise, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, task: &mut Option<Waker>, ) -> Result<(), UserError>99     pub fn send_push_promise<B>(
100         &mut self,
101         frame: frame::PushPromise,
102         buffer: &mut Buffer<Frame<B>>,
103         stream: &mut store::Ptr,
104         task: &mut Option<Waker>,
105     ) -> Result<(), UserError> {
106         if !self.is_push_enabled {
107             return Err(UserError::PeerDisabledServerPush);
108         }
109 
110         tracing::trace!(
111             "send_push_promise; frame={:?}; init_window={:?}",
112             frame,
113             self.init_window_sz
114         );
115 
116         Self::check_headers(frame.fields())?;
117 
118         // Queue the frame for sending
119         self.prioritize
120             .queue_frame(frame.into(), buffer, stream, task);
121 
122         Ok(())
123     }
124 
send_headers<B>( &mut self, frame: frame::Headers, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, task: &mut Option<Waker>, ) -> Result<(), UserError>125     pub fn send_headers<B>(
126         &mut self,
127         frame: frame::Headers,
128         buffer: &mut Buffer<Frame<B>>,
129         stream: &mut store::Ptr,
130         counts: &mut Counts,
131         task: &mut Option<Waker>,
132     ) -> Result<(), UserError> {
133         tracing::trace!(
134             "send_headers; frame={:?}; init_window={:?}",
135             frame,
136             self.init_window_sz
137         );
138 
139         Self::check_headers(frame.fields())?;
140 
141         let end_stream = frame.is_end_stream();
142 
143         // Update the state
144         stream.state.send_open(end_stream)?;
145 
146         let mut pending_open = false;
147         if counts.peer().is_local_init(frame.stream_id()) && !stream.is_pending_push {
148             self.prioritize.queue_open(stream);
149             pending_open = true;
150         }
151 
152         // Queue the frame for sending
153         //
154         // This call expects that, since new streams are in the open queue, new
155         // streams won't be pushed on pending_send.
156         self.prioritize
157             .queue_frame(frame.into(), buffer, stream, task);
158 
159         // Need to notify the connection when pushing onto pending_open since
160         // queue_frame only notifies for pending_send.
161         if pending_open {
162             if let Some(task) = task.take() {
163                 task.wake();
164             }
165         }
166 
167         Ok(())
168     }
169 
170     /// Send an explicit RST_STREAM frame
send_reset<B>( &mut self, reason: Reason, initiator: Initiator, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, task: &mut Option<Waker>, )171     pub fn send_reset<B>(
172         &mut self,
173         reason: Reason,
174         initiator: Initiator,
175         buffer: &mut Buffer<Frame<B>>,
176         stream: &mut store::Ptr,
177         counts: &mut Counts,
178         task: &mut Option<Waker>,
179     ) {
180         let is_reset = stream.state.is_reset();
181         let is_closed = stream.state.is_closed();
182         let is_empty = stream.pending_send.is_empty();
183         let stream_id = stream.id;
184 
185         tracing::trace!(
186             "send_reset(..., reason={:?}, initiator={:?}, stream={:?}, ..., \
187              is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \
188              state={:?} \
189              ",
190             reason,
191             initiator,
192             stream_id,
193             is_reset,
194             is_closed,
195             is_empty,
196             stream.state
197         );
198 
199         if is_reset {
200             // Don't double reset
201             tracing::trace!(
202                 " -> not sending RST_STREAM ({:?} is already reset)",
203                 stream_id
204             );
205             return;
206         }
207 
208         // Transition the state to reset no matter what.
209         stream.state.set_reset(stream_id, reason, initiator);
210 
211         // If closed AND the send queue is flushed, then the stream cannot be
212         // reset explicitly, either. Implicit resets can still be queued.
213         if is_closed && is_empty {
214             tracing::trace!(
215                 " -> not sending explicit RST_STREAM ({:?} was closed \
216                  and send queue was flushed)",
217                 stream_id
218             );
219             return;
220         }
221 
222         // Clear all pending outbound frames.
223         // Note that we don't call `self.recv_err` because we want to enqueue
224         // the reset frame before transitioning the stream inside
225         // `reclaim_all_capacity`.
226         self.prioritize.clear_queue(buffer, stream);
227 
228         let frame = frame::Reset::new(stream.id, reason);
229 
230         tracing::trace!("send_reset -- queueing; frame={:?}", frame);
231         self.prioritize
232             .queue_frame(frame.into(), buffer, stream, task);
233         self.prioritize.reclaim_all_capacity(stream, counts);
234     }
235 
schedule_implicit_reset( &mut self, stream: &mut store::Ptr, reason: Reason, counts: &mut Counts, task: &mut Option<Waker>, )236     pub fn schedule_implicit_reset(
237         &mut self,
238         stream: &mut store::Ptr,
239         reason: Reason,
240         counts: &mut Counts,
241         task: &mut Option<Waker>,
242     ) {
243         if stream.state.is_closed() {
244             // Stream is already closed, nothing more to do
245             return;
246         }
247 
248         stream.state.set_scheduled_reset(reason);
249 
250         self.prioritize.reclaim_reserved_capacity(stream, counts);
251         self.prioritize.schedule_send(stream, task);
252     }
253 
send_data<B>( &mut self, frame: frame::Data<B>, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, task: &mut Option<Waker>, ) -> Result<(), UserError> where B: Buf,254     pub fn send_data<B>(
255         &mut self,
256         frame: frame::Data<B>,
257         buffer: &mut Buffer<Frame<B>>,
258         stream: &mut store::Ptr,
259         counts: &mut Counts,
260         task: &mut Option<Waker>,
261     ) -> Result<(), UserError>
262     where
263         B: Buf,
264     {
265         self.prioritize
266             .send_data(frame, buffer, stream, counts, task)
267     }
268 
send_trailers<B>( &mut self, frame: frame::Headers, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, task: &mut Option<Waker>, ) -> Result<(), UserError>269     pub fn send_trailers<B>(
270         &mut self,
271         frame: frame::Headers,
272         buffer: &mut Buffer<Frame<B>>,
273         stream: &mut store::Ptr,
274         counts: &mut Counts,
275         task: &mut Option<Waker>,
276     ) -> Result<(), UserError> {
277         // TODO: Should this logic be moved into state.rs?
278         if !stream.state.is_send_streaming() {
279             return Err(UserError::UnexpectedFrameType);
280         }
281 
282         stream.state.send_close();
283 
284         tracing::trace!("send_trailers -- queuing; frame={:?}", frame);
285         self.prioritize
286             .queue_frame(frame.into(), buffer, stream, task);
287 
288         // Release any excess capacity
289         self.prioritize.reserve_capacity(0, stream, counts);
290 
291         Ok(())
292     }
293 
poll_complete<T, B>( &mut self, cx: &mut Context, buffer: &mut Buffer<Frame<B>>, store: &mut Store, counts: &mut Counts, dst: &mut Codec<T, Prioritized<B>>, ) -> Poll<io::Result<()>> where T: AsyncWrite + Unpin, B: Buf,294     pub fn poll_complete<T, B>(
295         &mut self,
296         cx: &mut Context,
297         buffer: &mut Buffer<Frame<B>>,
298         store: &mut Store,
299         counts: &mut Counts,
300         dst: &mut Codec<T, Prioritized<B>>,
301     ) -> Poll<io::Result<()>>
302     where
303         T: AsyncWrite + Unpin,
304         B: Buf,
305     {
306         self.prioritize
307             .poll_complete(cx, buffer, store, counts, dst)
308     }
309 
310     /// Request capacity to send data
reserve_capacity( &mut self, capacity: WindowSize, stream: &mut store::Ptr, counts: &mut Counts, )311     pub fn reserve_capacity(
312         &mut self,
313         capacity: WindowSize,
314         stream: &mut store::Ptr,
315         counts: &mut Counts,
316     ) {
317         self.prioritize.reserve_capacity(capacity, stream, counts)
318     }
319 
poll_capacity( &mut self, cx: &Context, stream: &mut store::Ptr, ) -> Poll<Option<Result<WindowSize, UserError>>>320     pub fn poll_capacity(
321         &mut self,
322         cx: &Context,
323         stream: &mut store::Ptr,
324     ) -> Poll<Option<Result<WindowSize, UserError>>> {
325         if !stream.state.is_send_streaming() {
326             return Poll::Ready(None);
327         }
328 
329         if !stream.send_capacity_inc {
330             stream.wait_send(cx);
331             return Poll::Pending;
332         }
333 
334         stream.send_capacity_inc = false;
335 
336         Poll::Ready(Some(Ok(self.capacity(stream))))
337     }
338 
339     /// Current available stream send capacity
capacity(&self, stream: &mut store::Ptr) -> WindowSize340     pub fn capacity(&self, stream: &mut store::Ptr) -> WindowSize {
341         stream.capacity(self.prioritize.max_buffer_size())
342     }
343 
poll_reset( &self, cx: &Context, stream: &mut Stream, mode: PollReset, ) -> Poll<Result<Reason, crate::Error>>344     pub fn poll_reset(
345         &self,
346         cx: &Context,
347         stream: &mut Stream,
348         mode: PollReset,
349     ) -> Poll<Result<Reason, crate::Error>> {
350         match stream.state.ensure_reason(mode)? {
351             Some(reason) => Poll::Ready(Ok(reason)),
352             None => {
353                 stream.wait_send(cx);
354                 Poll::Pending
355             }
356         }
357     }
358 
recv_connection_window_update( &mut self, frame: frame::WindowUpdate, store: &mut Store, counts: &mut Counts, ) -> Result<(), Reason>359     pub fn recv_connection_window_update(
360         &mut self,
361         frame: frame::WindowUpdate,
362         store: &mut Store,
363         counts: &mut Counts,
364     ) -> Result<(), Reason> {
365         self.prioritize
366             .recv_connection_window_update(frame.size_increment(), store, counts)
367     }
368 
recv_stream_window_update<B>( &mut self, sz: WindowSize, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, task: &mut Option<Waker>, ) -> Result<(), Reason>369     pub fn recv_stream_window_update<B>(
370         &mut self,
371         sz: WindowSize,
372         buffer: &mut Buffer<Frame<B>>,
373         stream: &mut store::Ptr,
374         counts: &mut Counts,
375         task: &mut Option<Waker>,
376     ) -> Result<(), Reason> {
377         if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) {
378             tracing::debug!("recv_stream_window_update !!; err={:?}", e);
379 
380             self.send_reset(
381                 Reason::FLOW_CONTROL_ERROR,
382                 Initiator::Library,
383                 buffer,
384                 stream,
385                 counts,
386                 task,
387             );
388 
389             return Err(e);
390         }
391 
392         Ok(())
393     }
394 
recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), Error>395     pub(super) fn recv_go_away(&mut self, last_stream_id: StreamId) -> Result<(), Error> {
396         if last_stream_id > self.max_stream_id {
397             // The remote endpoint sent a `GOAWAY` frame indicating a stream
398             // that we never sent, or that we have already terminated on account
399             // of previous `GOAWAY` frame. In either case, that is illegal.
400             // (When sending multiple `GOAWAY`s, "Endpoints MUST NOT increase
401             // the value they send in the last stream identifier, since the
402             // peers might already have retried unprocessed requests on another
403             // connection.")
404             proto_err!(conn:
405                 "recv_go_away: last_stream_id ({:?}) > max_stream_id ({:?})",
406                 last_stream_id, self.max_stream_id,
407             );
408             return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
409         }
410 
411         self.max_stream_id = last_stream_id;
412         Ok(())
413     }
414 
handle_error<B>( &mut self, buffer: &mut Buffer<Frame<B>>, stream: &mut store::Ptr, counts: &mut Counts, )415     pub fn handle_error<B>(
416         &mut self,
417         buffer: &mut Buffer<Frame<B>>,
418         stream: &mut store::Ptr,
419         counts: &mut Counts,
420     ) {
421         // Clear all pending outbound frames
422         self.prioritize.clear_queue(buffer, stream);
423         self.prioritize.reclaim_all_capacity(stream, counts);
424     }
425 
apply_remote_settings<B>( &mut self, settings: &frame::Settings, buffer: &mut Buffer<Frame<B>>, store: &mut Store, counts: &mut Counts, task: &mut Option<Waker>, ) -> Result<(), Error>426     pub fn apply_remote_settings<B>(
427         &mut self,
428         settings: &frame::Settings,
429         buffer: &mut Buffer<Frame<B>>,
430         store: &mut Store,
431         counts: &mut Counts,
432         task: &mut Option<Waker>,
433     ) -> Result<(), Error> {
434         if let Some(val) = settings.is_extended_connect_protocol_enabled() {
435             self.is_extended_connect_protocol_enabled = val;
436         }
437 
438         // Applies an update to the remote endpoint's initial window size.
439         //
440         // Per RFC 7540 §6.9.2:
441         //
442         // In addition to changing the flow-control window for streams that are
443         // not yet active, a SETTINGS frame can alter the initial flow-control
444         // window size for streams with active flow-control windows (that is,
445         // streams in the "open" or "half-closed (remote)" state). When the
446         // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust
447         // the size of all stream flow-control windows that it maintains by the
448         // difference between the new value and the old value.
449         //
450         // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available
451         // space in a flow-control window to become negative. A sender MUST
452         // track the negative flow-control window and MUST NOT send new
453         // flow-controlled frames until it receives WINDOW_UPDATE frames that
454         // cause the flow-control window to become positive.
455         if let Some(val) = settings.initial_window_size() {
456             let old_val = self.init_window_sz;
457             self.init_window_sz = val;
458 
459             match val.cmp(&old_val) {
460                 Ordering::Less => {
461                     // We must decrease the (remote) window on every open stream.
462                     let dec = old_val - val;
463                     tracing::trace!("decrementing all windows; dec={}", dec);
464 
465                     let mut total_reclaimed = 0;
466                     store.try_for_each(|mut stream| {
467                         let stream = &mut *stream;
468 
469                         tracing::trace!(
470                             "decrementing stream window; id={:?}; decr={}; flow={:?}",
471                             stream.id,
472                             dec,
473                             stream.send_flow
474                         );
475 
476                         // TODO: this decrement can underflow based on received frames!
477                         stream
478                             .send_flow
479                             .dec_send_window(dec)
480                             .map_err(proto::Error::library_go_away)?;
481 
482                         // It's possible that decreasing the window causes
483                         // `window_size` (the stream-specific window) to fall below
484                         // `available` (the portion of the connection-level window
485                         // that we have allocated to the stream).
486                         // In this case, we should take that excess allocation away
487                         // and reassign it to other streams.
488                         let window_size = stream.send_flow.window_size();
489                         let available = stream.send_flow.available().as_size();
490                         let reclaimed = if available > window_size {
491                             // Drop down to `window_size`.
492                             let reclaim = available - window_size;
493                             stream
494                                 .send_flow
495                                 .claim_capacity(reclaim)
496                                 .map_err(proto::Error::library_go_away)?;
497                             total_reclaimed += reclaim;
498                             reclaim
499                         } else {
500                             0
501                         };
502 
503                         tracing::trace!(
504                             "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}",
505                             stream.id,
506                             dec,
507                             reclaimed,
508                             stream.send_flow
509                         );
510 
511                         // TODO: Should this notify the producer when the capacity
512                         // of a stream is reduced? Maybe it should if the capacity
513                         // is reduced to zero, allowing the producer to stop work.
514 
515                         Ok::<_, proto::Error>(())
516                     })?;
517 
518                     self.prioritize
519                         .assign_connection_capacity(total_reclaimed, store, counts);
520                 }
521                 Ordering::Greater => {
522                     let inc = val - old_val;
523 
524                     store.try_for_each(|mut stream| {
525                         self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
526                             .map_err(Error::library_go_away)
527                     })?;
528                 }
529                 Ordering::Equal => (),
530             }
531         }
532 
533         if let Some(val) = settings.is_push_enabled() {
534             self.is_push_enabled = val
535         }
536 
537         Ok(())
538     }
539 
clear_queues(&mut self, store: &mut Store, counts: &mut Counts)540     pub fn clear_queues(&mut self, store: &mut Store, counts: &mut Counts) {
541         self.prioritize.clear_pending_capacity(store, counts);
542         self.prioritize.clear_pending_send(store, counts);
543         self.prioritize.clear_pending_open(store, counts);
544     }
545 
ensure_not_idle(&self, id: StreamId) -> Result<(), Reason>546     pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
547         if let Ok(next) = self.next_stream_id {
548             if id >= next {
549                 return Err(Reason::PROTOCOL_ERROR);
550             }
551         }
552         // if next_stream_id is overflowed, that's ok.
553 
554         Ok(())
555     }
556 
ensure_next_stream_id(&self) -> Result<StreamId, UserError>557     pub fn ensure_next_stream_id(&self) -> Result<StreamId, UserError> {
558         self.next_stream_id
559             .map_err(|_| UserError::OverflowedStreamId)
560     }
561 
may_have_created_stream(&self, id: StreamId) -> bool562     pub fn may_have_created_stream(&self, id: StreamId) -> bool {
563         if let Ok(next_id) = self.next_stream_id {
564             // Peer::is_local_init should have been called beforehand
565             debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated(),);
566             id < next_id
567         } else {
568             true
569         }
570     }
571 
maybe_reset_next_stream_id(&mut self, id: StreamId)572     pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) {
573         if let Ok(next_id) = self.next_stream_id {
574             // Peer::is_local_init should have been called beforehand
575             debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated());
576             if id >= next_id {
577                 self.next_stream_id = id.next_id();
578             }
579         }
580     }
581 
is_extended_connect_protocol_enabled(&self) -> bool582     pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
583         self.is_extended_connect_protocol_enabled
584     }
585 }
586