1 //! Handle WebSocket connections.
2 //!
3 //! # Example
4 //!
5 //! ```
6 //! use axum::{
7 //!     extract::ws::{WebSocketUpgrade, WebSocket},
8 //!     routing::get,
9 //!     response::{IntoResponse, Response},
10 //!     Router,
11 //! };
12 //!
13 //! let app = Router::new().route("/ws", get(handler));
14 //!
15 //! async fn handler(ws: WebSocketUpgrade) -> Response {
16 //!     ws.on_upgrade(handle_socket)
17 //! }
18 //!
19 //! async fn handle_socket(mut socket: WebSocket) {
20 //!     while let Some(msg) = socket.recv().await {
21 //!         let msg = if let Ok(msg) = msg {
22 //!             msg
23 //!         } else {
24 //!             // client disconnected
25 //!             return;
26 //!         };
27 //!
28 //!         if socket.send(msg).await.is_err() {
29 //!             // client disconnected
30 //!             return;
31 //!         }
32 //!     }
33 //! }
34 //! # async {
35 //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
36 //! # };
37 //! ```
38 //!
39 //! # Passing data and/or state to an `on_upgrade` callback
40 //!
41 //! ```
42 //! use axum::{
43 //!     extract::{ws::{WebSocketUpgrade, WebSocket}, State},
44 //!     response::Response,
45 //!     routing::get,
46 //!     Router,
47 //! };
48 //!
49 //! #[derive(Clone)]
50 //! struct AppState {
51 //!     // ...
52 //! }
53 //!
54 //! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
55 //!     ws.on_upgrade(|socket| handle_socket(socket, state))
56 //! }
57 //!
58 //! async fn handle_socket(socket: WebSocket, state: AppState) {
59 //!     // ...
60 //! }
61 //!
62 //! let app = Router::new()
63 //!     .route("/ws", get(handler))
64 //!     .with_state(AppState { /* ... */ });
65 //! # async {
66 //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
67 //! # };
68 //! ```
69 //!
70 //! # Read and write concurrently
71 //!
//! If you need to read from and write to a [`WebSocket`] concurrently, you can split it into a
//! sender and a receiver with [`StreamExt::split`]:
74 //!
75 //! ```rust,no_run
//! use axum::extract::ws::{WebSocket, Message};
//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
//!
//! async fn handle_socket(socket: WebSocket) {
//!     let (sender, receiver) = socket.split();
81 //!
82 //!     tokio::spawn(write(sender));
83 //!     tokio::spawn(read(receiver));
84 //! }
85 //!
86 //! async fn read(receiver: SplitStream<WebSocket>) {
87 //!     // ...
88 //! }
89 //!
90 //! async fn write(sender: SplitSink<WebSocket, Message>) {
91 //!     // ...
92 //! }
93 //! ```
94 //!
95 //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
96 
97 use self::rejection::*;
98 use super::FromRequestParts;
99 use crate::{
100     body::{self, Bytes},
101     response::Response,
102     Error,
103 };
104 use async_trait::async_trait;
105 use futures_util::{
106     sink::{Sink, SinkExt},
107     stream::{Stream, StreamExt},
108 };
109 use http::{
110     header::{self, HeaderMap, HeaderName, HeaderValue},
111     request::Parts,
112     Method, StatusCode,
113 };
114 use hyper::upgrade::{OnUpgrade, Upgraded};
115 use sha1::{Digest, Sha1};
116 use std::{
117     borrow::Cow,
118     future::Future,
119     pin::Pin,
120     task::{Context, Poll},
121 };
122 use tokio_tungstenite::{
123     tungstenite::{
124         self as ts,
125         protocol::{self, WebSocketConfig},
126     },
127     WebSocketStream,
128 };
129 
130 /// Extractor for establishing WebSocket connections.
131 ///
132 /// Note: This extractor requires the request method to be `GET` so it should
133 /// always be used with [`get`](crate::routing::get). Requests with other methods will be
134 /// rejected.
135 ///
136 /// See the [module docs](self) for an example.
137 #[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
138 pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
139     config: WebSocketConfig,
140     /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
141     protocol: Option<HeaderValue>,
142     sec_websocket_key: HeaderValue,
143     on_upgrade: OnUpgrade,
144     on_failed_upgrade: F,
145     sec_websocket_protocol: Option<HeaderValue>,
146 }
147 
148 impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result149     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150         f.debug_struct("WebSocketUpgrade")
151             .field("config", &self.config)
152             .field("protocol", &self.protocol)
153             .field("sec_websocket_key", &self.sec_websocket_key)
154             .field("sec_websocket_protocol", &self.sec_websocket_protocol)
155             .finish_non_exhaustive()
156     }
157 }
158 
159 impl<F> WebSocketUpgrade<F> {
160     /// Does nothing, instead use `max_write_buffer_size`.
161     #[deprecated]
max_send_queue(self, _: usize) -> Self162     pub fn max_send_queue(self, _: usize) -> Self {
163         self
164     }
165 
166     /// The target minimum size of the write buffer to reach before writing the data
167     /// to the underlying stream.
168     ///
169     /// The default value is 128 KiB.
170     ///
171     /// If set to `0` each message will be eagerly written to the underlying stream.
172     /// It is often more optimal to allow them to buffer a little, hence the default value.
173     ///
174     /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
write_buffer_size(mut self, size: usize) -> Self175     pub fn write_buffer_size(mut self, size: usize) -> Self {
176         self.config.write_buffer_size = size;
177         self
178     }
179 
180     /// The max size of the write buffer in bytes. Setting this can provide backpressure
181     /// in the case the write buffer is filling up due to write errors.
182     ///
183     /// The default value is unlimited.
184     ///
185     /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
186     /// when writes to the underlying stream are failing. So the **write buffer can not
187     /// fill up if you are not observing write errors even if not flushing**.
188     ///
189     /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
190     /// and probably a little more depending on error handling strategy.
max_write_buffer_size(mut self, max: usize) -> Self191     pub fn max_write_buffer_size(mut self, max: usize) -> Self {
192         self.config.max_write_buffer_size = max;
193         self
194     }
195 
196     /// Set the maximum message size (defaults to 64 megabytes)
max_message_size(mut self, max: usize) -> Self197     pub fn max_message_size(mut self, max: usize) -> Self {
198         self.config.max_message_size = Some(max);
199         self
200     }
201 
202     /// Set the maximum frame size (defaults to 16 megabytes)
max_frame_size(mut self, max: usize) -> Self203     pub fn max_frame_size(mut self, max: usize) -> Self {
204         self.config.max_frame_size = Some(max);
205         self
206     }
207 
208     /// Allow server to accept unmasked frames (defaults to false)
accept_unmasked_frames(mut self, accept: bool) -> Self209     pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
210         self.config.accept_unmasked_frames = accept;
211         self
212     }
213 
214     /// Set the known protocols.
215     ///
216     /// If the protocol name specified by `Sec-WebSocket-Protocol` header
217     /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
218     /// return the protocol name.
219     ///
220     /// The protocols should be listed in decreasing order of preference: if the client offers
221     /// multiple protocols that the server could support, the server will pick the first one in
222     /// this list.
223     ///
224     /// # Examples
225     ///
226     /// ```
227     /// use axum::{
228     ///     extract::ws::{WebSocketUpgrade, WebSocket},
229     ///     routing::get,
230     ///     response::{IntoResponse, Response},
231     ///     Router,
232     /// };
233     ///
234     /// let app = Router::new().route("/ws", get(handler));
235     ///
236     /// async fn handler(ws: WebSocketUpgrade) -> Response {
237     ///     ws.protocols(["graphql-ws", "graphql-transport-ws"])
238     ///         .on_upgrade(|socket| async {
239     ///             // ...
240     ///         })
241     /// }
242     /// # async {
243     /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
244     /// # };
245     /// ```
protocols<I>(mut self, protocols: I) -> Self where I: IntoIterator, I::Item: Into<Cow<'static, str>>,246     pub fn protocols<I>(mut self, protocols: I) -> Self
247     where
248         I: IntoIterator,
249         I::Item: Into<Cow<'static, str>>,
250     {
251         if let Some(req_protocols) = self
252             .sec_websocket_protocol
253             .as_ref()
254             .and_then(|p| p.to_str().ok())
255         {
256             self.protocol = protocols
257                 .into_iter()
258                 // FIXME: This will often allocate a new `String` and so is less efficient than it
259                 // could be. But that can't be fixed without breaking changes to the public API.
260                 .map(Into::into)
261                 .find(|protocol| {
262                     req_protocols
263                         .split(',')
264                         .any(|req_protocol| req_protocol.trim() == protocol)
265                 })
266                 .map(|protocol| match protocol {
267                     Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
268                     Cow::Borrowed(s) => HeaderValue::from_static(s),
269                 });
270         }
271 
272         self
273     }
274 
275     /// Provide a callback to call if upgrading the connection fails.
276     ///
277     /// The connection upgrade is performed in a background task. If that fails this callback
278     /// will be called.
279     ///
280     /// By default any errors will be silently ignored.
281     ///
282     /// # Example
283     ///
284     /// ```
285     /// use axum::{
286     ///     extract::{WebSocketUpgrade},
287     ///     response::Response,
288     /// };
289     ///
290     /// async fn handler(ws: WebSocketUpgrade) -> Response {
291     ///     ws.on_failed_upgrade(|error| {
292     ///         report_error(error);
293     ///     })
294     ///     .on_upgrade(|socket| async { /* ... */ })
295     /// }
296     /// #
297     /// # fn report_error(_: axum::Error) {}
298     /// ```
on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C> where C: OnFailedUpdgrade,299     pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
300     where
301         C: OnFailedUpdgrade,
302     {
303         WebSocketUpgrade {
304             config: self.config,
305             protocol: self.protocol,
306             sec_websocket_key: self.sec_websocket_key,
307             on_upgrade: self.on_upgrade,
308             on_failed_upgrade: callback,
309             sec_websocket_protocol: self.sec_websocket_protocol,
310         }
311     }
312 
313     /// Finalize upgrading the connection and call the provided callback with
314     /// the stream.
315     #[must_use = "to setup the WebSocket connection, this response must be returned"]
on_upgrade<C, Fut>(self, callback: C) -> Response where C: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future<Output = ()> + Send + 'static, F: OnFailedUpdgrade,316     pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
317     where
318         C: FnOnce(WebSocket) -> Fut + Send + 'static,
319         Fut: Future<Output = ()> + Send + 'static,
320         F: OnFailedUpdgrade,
321     {
322         let on_upgrade = self.on_upgrade;
323         let config = self.config;
324         let on_failed_upgrade = self.on_failed_upgrade;
325 
326         let protocol = self.protocol.clone();
327 
328         tokio::spawn(async move {
329             let upgraded = match on_upgrade.await {
330                 Ok(upgraded) => upgraded,
331                 Err(err) => {
332                     on_failed_upgrade.call(Error::new(err));
333                     return;
334                 }
335             };
336 
337             let socket =
338                 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
339                     .await;
340             let socket = WebSocket {
341                 inner: socket,
342                 protocol,
343             };
344             callback(socket).await;
345         });
346 
347         #[allow(clippy::declare_interior_mutable_const)]
348         const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
349         #[allow(clippy::declare_interior_mutable_const)]
350         const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
351 
352         let mut builder = Response::builder()
353             .status(StatusCode::SWITCHING_PROTOCOLS)
354             .header(header::CONNECTION, UPGRADE)
355             .header(header::UPGRADE, WEBSOCKET)
356             .header(
357                 header::SEC_WEBSOCKET_ACCEPT,
358                 sign(self.sec_websocket_key.as_bytes()),
359             );
360 
361         if let Some(protocol) = self.protocol {
362             builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
363         }
364 
365         builder.body(body::boxed(body::Empty::new())).unwrap()
366     }
367 }
368 
369 /// What to do when a connection upgrade fails.
370 ///
371 /// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
372 pub trait OnFailedUpdgrade: Send + 'static {
373     /// Call the callback.
call(self, error: Error)374     fn call(self, error: Error);
375 }
376 
377 impl<F> OnFailedUpdgrade for F
378 where
379     F: FnOnce(Error) + Send + 'static,
380 {
call(self, error: Error)381     fn call(self, error: Error) {
382         self(error)
383     }
384 }
385 
386 /// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
387 ///
388 /// It simply ignores the error.
389 #[non_exhaustive]
390 #[derive(Debug)]
391 pub struct DefaultOnFailedUpdgrade;
392 
393 impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
394     #[inline]
call(self, _error: Error)395     fn call(self, _error: Error) {}
396 }
397 
398 #[async_trait]
399 impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
400 where
401     S: Send + Sync,
402 {
403     type Rejection = WebSocketUpgradeRejection;
404 
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>405     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
406         if parts.method != Method::GET {
407             return Err(MethodNotGet.into());
408         }
409 
410         if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
411             return Err(InvalidConnectionHeader.into());
412         }
413 
414         if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
415             return Err(InvalidUpgradeHeader.into());
416         }
417 
418         if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
419             return Err(InvalidWebSocketVersionHeader.into());
420         }
421 
422         let sec_websocket_key = parts
423             .headers
424             .get(header::SEC_WEBSOCKET_KEY)
425             .ok_or(WebSocketKeyHeaderMissing)?
426             .clone();
427 
428         let on_upgrade = parts
429             .extensions
430             .remove::<OnUpgrade>()
431             .ok_or(ConnectionNotUpgradable)?;
432 
433         let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
434 
435         Ok(Self {
436             config: Default::default(),
437             protocol: None,
438             sec_websocket_key,
439             on_upgrade,
440             sec_websocket_protocol,
441             on_failed_upgrade: DefaultOnFailedUpdgrade,
442         })
443     }
444 }
445 
header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool446 fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
447     if let Some(header) = headers.get(&key) {
448         header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
449     } else {
450         false
451     }
452 }
453 
header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool454 fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
455     let header = if let Some(header) = headers.get(&key) {
456         header
457     } else {
458         return false;
459     };
460 
461     if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
462         header.to_ascii_lowercase().contains(value)
463     } else {
464         false
465     }
466 }
467 
468 /// A stream of WebSocket messages.
469 ///
470 /// See [the module level documentation](self) for more details.
471 #[derive(Debug)]
472 pub struct WebSocket {
473     inner: WebSocketStream<Upgraded>,
474     protocol: Option<HeaderValue>,
475 }
476 
477 impl WebSocket {
478     /// Receive another message.
479     ///
480     /// Returns `None` if the stream has closed.
recv(&mut self) -> Option<Result<Message, Error>>481     pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
482         self.next().await
483     }
484 
485     /// Send a message.
send(&mut self, msg: Message) -> Result<(), Error>486     pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
487         self.inner
488             .send(msg.into_tungstenite())
489             .await
490             .map_err(Error::new)
491     }
492 
493     /// Gracefully close this WebSocket.
close(mut self) -> Result<(), Error>494     pub async fn close(mut self) -> Result<(), Error> {
495         self.inner.close(None).await.map_err(Error::new)
496     }
497 
498     /// Return the selected WebSocket subprotocol, if one has been chosen.
protocol(&self) -> Option<&HeaderValue>499     pub fn protocol(&self) -> Option<&HeaderValue> {
500         self.protocol.as_ref()
501     }
502 }
503 
504 impl Stream for WebSocket {
505     type Item = Result<Message, Error>;
506 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>507     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
508         loop {
509             match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
510                 Some(Ok(msg)) => {
511                     if let Some(msg) = Message::from_tungstenite(msg) {
512                         return Poll::Ready(Some(Ok(msg)));
513                     }
514                 }
515                 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
516                 None => return Poll::Ready(None),
517             }
518         }
519     }
520 }
521 
522 impl Sink<Message> for WebSocket {
523     type Error = Error;
524 
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>525     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
526         Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
527     }
528 
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>529     fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
530         Pin::new(&mut self.inner)
531             .start_send(item.into_tungstenite())
532             .map_err(Error::new)
533     }
534 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>535     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
536         Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
537     }
538 
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>539     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
540         Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
541     }
542 }
543 
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Well-known values are available in the [`close_code`] module.
pub type CloseCode = u16;

/// A struct representing the close command: a status code plus a textual reason.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CloseFrame<'t> {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Cow<'t, str>,
}
555 
556 /// A WebSocket message.
557 //
558 // This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
559 // Copyright (c) 2017 Alexey Galakhov
560 // Copyright (c) 2016 Jason Housley
561 //
562 // Permission is hereby granted, free of charge, to any person obtaining a copy
563 // of this software and associated documentation files (the "Software"), to deal
564 // in the Software without restriction, including without limitation the rights
565 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
566 // copies of the Software, and to permit persons to whom the Software is
567 // furnished to do so, subject to the following conditions:
568 //
569 // The above copyright notice and this permission notice shall be included in
570 // all copies or substantial portions of the Software.
571 //
572 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
573 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
574 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
575 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
576 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
577 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
578 // THE SOFTWARE.
579 #[derive(Debug, Eq, PartialEq, Clone)]
580 pub enum Message {
581     /// A text WebSocket message
582     Text(String),
583     /// A binary WebSocket message
584     Binary(Vec<u8>),
585     /// A ping message with the specified payload
586     ///
587     /// The payload here must have a length less than 125 bytes.
588     ///
589     /// Ping messages will be automatically responded to by the server, so you do not have to worry
590     /// about dealing with them yourself.
591     Ping(Vec<u8>),
592     /// A pong message with the specified payload
593     ///
594     /// The payload here must have a length less than 125 bytes.
595     ///
596     /// Pong messages will be automatically sent to the client if a ping message is received, so
597     /// you do not have to worry about constructing them yourself unless you want to implement a
598     /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
599     Pong(Vec<u8>),
600     /// A close message with the optional close frame.
601     Close(Option<CloseFrame<'static>>),
602 }
603 
604 impl Message {
into_tungstenite(self) -> ts::Message605     fn into_tungstenite(self) -> ts::Message {
606         match self {
607             Self::Text(text) => ts::Message::Text(text),
608             Self::Binary(binary) => ts::Message::Binary(binary),
609             Self::Ping(ping) => ts::Message::Ping(ping),
610             Self::Pong(pong) => ts::Message::Pong(pong),
611             Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
612                 code: ts::protocol::frame::coding::CloseCode::from(close.code),
613                 reason: close.reason,
614             })),
615             Self::Close(None) => ts::Message::Close(None),
616         }
617     }
618 
from_tungstenite(message: ts::Message) -> Option<Self>619     fn from_tungstenite(message: ts::Message) -> Option<Self> {
620         match message {
621             ts::Message::Text(text) => Some(Self::Text(text)),
622             ts::Message::Binary(binary) => Some(Self::Binary(binary)),
623             ts::Message::Ping(ping) => Some(Self::Ping(ping)),
624             ts::Message::Pong(pong) => Some(Self::Pong(pong)),
625             ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
626                 code: close.code.into(),
627                 reason: close.reason,
628             }))),
629             ts::Message::Close(None) => Some(Self::Close(None)),
630             // we can ignore `Frame` frames as recommended by the tungstenite maintainers
631             // https://github.com/snapview/tungstenite-rs/issues/268
632             ts::Message::Frame(_) => None,
633         }
634     }
635 
636     /// Consume the WebSocket and return it as binary data.
into_data(self) -> Vec<u8>637     pub fn into_data(self) -> Vec<u8> {
638         match self {
639             Self::Text(string) => string.into_bytes(),
640             Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
641             Self::Close(None) => Vec::new(),
642             Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
643         }
644     }
645 
646     /// Attempt to consume the WebSocket message and convert it to a String.
into_text(self) -> Result<String, Error>647     pub fn into_text(self) -> Result<String, Error> {
648         match self {
649             Self::Text(string) => Ok(string),
650             Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
651                 .map_err(|err| err.utf8_error())
652                 .map_err(Error::new)?),
653             Self::Close(None) => Ok(String::new()),
654             Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
655         }
656     }
657 
658     /// Attempt to get a &str from the WebSocket message,
659     /// this will try to convert binary data to utf8.
to_text(&self) -> Result<&str, Error>660     pub fn to_text(&self) -> Result<&str, Error> {
661         match *self {
662             Self::Text(ref string) => Ok(string),
663             Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
664                 Ok(std::str::from_utf8(data).map_err(Error::new)?)
665             }
666             Self::Close(None) => Ok(""),
667             Self::Close(Some(ref frame)) => Ok(&frame.reason),
668         }
669     }
670 }
671 
672 impl From<String> for Message {
from(string: String) -> Self673     fn from(string: String) -> Self {
674         Message::Text(string)
675     }
676 }
677 
678 impl<'s> From<&'s str> for Message {
from(string: &'s str) -> Self679     fn from(string: &'s str) -> Self {
680         Message::Text(string.into())
681     }
682 }
683 
684 impl<'b> From<&'b [u8]> for Message {
from(data: &'b [u8]) -> Self685     fn from(data: &'b [u8]) -> Self {
686         Message::Binary(data.into())
687     }
688 }
689 
690 impl From<Vec<u8>> for Message {
from(data: Vec<u8>) -> Self691     fn from(data: Vec<u8>) -> Self {
692         Message::Binary(data)
693     }
694 }
695 
696 impl From<Message> for Vec<u8> {
from(msg: Message) -> Self697     fn from(msg: Message) -> Self {
698         msg.into_data()
699     }
700 }
701 
sign(key: &[u8]) -> HeaderValue702 fn sign(key: &[u8]) -> HeaderValue {
703     use base64::engine::Engine as _;
704 
705     let mut sha1 = Sha1::default();
706     sha1.update(key);
707     sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
708     let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
709     HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
710 }
711 
pub mod rejection {
    //! WebSocket specific rejections.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    // Returned when the handshake request does not use the `GET` method.
    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        pub struct MethodNotGet;
    }

    // Returned when the `Connection` header does not contain the `upgrade` token.
    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        pub struct InvalidConnectionHeader;
    }

    // Returned when the `Upgrade` header is missing or not `websocket` (case-insensitive).
    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        pub struct InvalidUpgradeHeader;
    }

    // Returned when the `Sec-WebSocket-Version` header is missing or not `13`.
    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        pub struct InvalidWebSocketVersionHeader;
    }

    // Returned when the request carries no `Sec-WebSocket-Key` header.
    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        pub struct WebSocketKeyHeaderMissing;
    }

    // Returned when the request extensions hold no upgrade handle, i.e. the connection
    // cannot be upgraded.
    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    // Umbrella rejection returned by the `WebSocketUpgrade` extractor; each variant wraps one
    // of the rejections defined above.
    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
782 
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Normal closure: the purpose for which the connection was established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// An endpoint is "going away", for example a server shutting down or a browser having
    /// navigated away from the page.
    pub const AWAY: u16 = 1001;

    /// An endpoint is terminating the connection because of a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// An endpoint is terminating the connection because it received a type of data it cannot
    /// accept (e.g. an endpoint that only understands text data MAY send this if it receives a
    /// binary message).
    pub const UNSUPPORTED: u16 = 1003;

    /// No status code was included in the closing frame.
    pub const STATUS: u16 = 1005;

    /// The connection was closed abnormally.
    pub const ABNORMAL: u16 = 1006;

    /// An endpoint is terminating the connection because it received data within a message that
    /// was inconsistent with the message type (e.g. non-UTF-8 RFC3629 data inside a text
    /// message).
    pub const INVALID: u16 = 1007;

    /// An endpoint is terminating the connection because it received a message that violates
    /// its policy. This generic code can be used when no more suitable code applies (e.g.
    /// `UNSUPPORTED` or `SIZE`), or when the details of the policy should be hidden.
    pub const POLICY: u16 = 1008;

    /// An endpoint is terminating the connection because it received a message too big to
    /// process.
    pub const SIZE: u16 = 1009;

    /// A client is terminating the connection because it expected the server to negotiate one or
    /// more extensions, but the server did not return them in the handshake response. The
    /// required extensions should be listed as the close reason. Servers do not use this code,
    /// because they can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// A server is terminating the connection because an unexpected condition prevented it from
    /// fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// The server is restarting.
    pub const RESTART: u16 = 1012;

    /// The server is overloaded. The client should connect to a different IP (when multiple
    /// targets exist), or reconnect to the same IP once a user has performed an action.
    pub const AGAIN: u16 = 1013;
}
844 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{Request, Version};
    use tower::ServiceExt;

    /// HTTP/1.0 connections cannot be upgraded, so the extractor must reject the request with
    /// `ConnectionNotUpgradable` even though every WebSocket header is present.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let request = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let service = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let response = service.oneshot(request).await.unwrap();

        // The handler itself returns `()`, so reaching it yields `200 OK`.
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Compile-only check (never called): a handler relying on the default failure callback
    // must type-check as a route.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_socket| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    // Compile-only check (never called): a custom failure callback must type-check as a route.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_socket| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }
}
895