1 //! Handle WebSocket connections.
2 //!
3 //! # Example
4 //!
5 //! ```
6 //! use axum::{
7 //! extract::ws::{WebSocketUpgrade, WebSocket},
8 //! routing::get,
9 //! response::{IntoResponse, Response},
10 //! Router,
11 //! };
12 //!
13 //! let app = Router::new().route("/ws", get(handler));
14 //!
15 //! async fn handler(ws: WebSocketUpgrade) -> Response {
16 //! ws.on_upgrade(handle_socket)
17 //! }
18 //!
19 //! async fn handle_socket(mut socket: WebSocket) {
20 //! while let Some(msg) = socket.recv().await {
21 //! let msg = if let Ok(msg) = msg {
22 //! msg
23 //! } else {
24 //! // client disconnected
25 //! return;
26 //! };
27 //!
28 //! if socket.send(msg).await.is_err() {
29 //! // client disconnected
30 //! return;
31 //! }
32 //! }
33 //! }
34 //! # async {
35 //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
36 //! # };
37 //! ```
38 //!
39 //! # Passing data and/or state to an `on_upgrade` callback
40 //!
41 //! ```
42 //! use axum::{
43 //! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
44 //! response::Response,
45 //! routing::get,
46 //! Router,
47 //! };
48 //!
49 //! #[derive(Clone)]
50 //! struct AppState {
51 //! // ...
52 //! }
53 //!
54 //! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
55 //! ws.on_upgrade(|socket| handle_socket(socket, state))
56 //! }
57 //!
58 //! async fn handle_socket(socket: WebSocket, state: AppState) {
59 //! // ...
60 //! }
61 //!
62 //! let app = Router::new()
63 //! .route("/ws", get(handler))
64 //! .with_state(AppState { /* ... */ });
65 //! # async {
66 //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
67 //! # };
68 //! ```
69 //!
70 //! # Read and write concurrently
71 //!
72 //! If you need to read and write concurrently from a [`WebSocket`] you can use
73 //! [`StreamExt::split`]:
74 //!
75 //! ```rust,no_run
76 //! use axum::{Error, extract::ws::{WebSocket, Message}};
77 //! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
78 //!
79 //! async fn handle_socket(mut socket: WebSocket) {
80 //! let (mut sender, mut receiver) = socket.split();
81 //!
82 //! tokio::spawn(write(sender));
83 //! tokio::spawn(read(receiver));
84 //! }
85 //!
86 //! async fn read(receiver: SplitStream<WebSocket>) {
87 //! // ...
88 //! }
89 //!
90 //! async fn write(sender: SplitSink<WebSocket, Message>) {
91 //! // ...
92 //! }
93 //! ```
94 //!
95 //! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
96
97 use self::rejection::*;
98 use super::FromRequestParts;
99 use crate::{
100 body::{self, Bytes},
101 response::Response,
102 Error,
103 };
104 use async_trait::async_trait;
105 use futures_util::{
106 sink::{Sink, SinkExt},
107 stream::{Stream, StreamExt},
108 };
109 use http::{
110 header::{self, HeaderMap, HeaderName, HeaderValue},
111 request::Parts,
112 Method, StatusCode,
113 };
114 use hyper::upgrade::{OnUpgrade, Upgraded};
115 use sha1::{Digest, Sha1};
116 use std::{
117 borrow::Cow,
118 future::Future,
119 pin::Pin,
120 task::{Context, Poll},
121 };
122 use tokio_tungstenite::{
123 tungstenite::{
124 self as ts,
125 protocol::{self, WebSocketConfig},
126 },
127 WebSocketStream,
128 };
129
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
///
/// The type parameter `F` is the callback that is invoked when the connection upgrade
/// fails. It defaults to [`DefaultOnFailedUpdgrade`] and can be replaced with
/// [`WebSocketUpgrade::on_failed_upgrade`].
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
    /// Settings passed to the WebSocket stream once the connection has been upgraded.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// Value of the request's `Sec-WebSocket-Key` header; it is signed to produce the
    /// `Sec-WebSocket-Accept` response header.
    sec_websocket_key: HeaderValue,
    /// Pending connection upgrade, taken out of the request extensions.
    on_upgrade: OnUpgrade,
    /// Callback invoked if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// The request's `Sec-WebSocket-Protocol` header, if the client sent one.
    sec_websocket_protocol: Option<HeaderValue>,
}
147
148 impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("WebSocketUpgrade")
151 .field("config", &self.config)
152 .field("protocol", &self.protocol)
153 .field("sec_websocket_key", &self.sec_websocket_key)
154 .field("sec_websocket_protocol", &self.sec_websocket_protocol)
155 .finish_non_exhaustive()
156 }
157 }
158
159 impl<F> WebSocketUpgrade<F> {
160 /// Does nothing, instead use `max_write_buffer_size`.
161 #[deprecated]
max_send_queue(self, _: usize) -> Self162 pub fn max_send_queue(self, _: usize) -> Self {
163 self
164 }
165
166 /// The target minimum size of the write buffer to reach before writing the data
167 /// to the underlying stream.
168 ///
169 /// The default value is 128 KiB.
170 ///
171 /// If set to `0` each message will be eagerly written to the underlying stream.
172 /// It is often more optimal to allow them to buffer a little, hence the default value.
173 ///
174 /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
write_buffer_size(mut self, size: usize) -> Self175 pub fn write_buffer_size(mut self, size: usize) -> Self {
176 self.config.write_buffer_size = size;
177 self
178 }
179
180 /// The max size of the write buffer in bytes. Setting this can provide backpressure
181 /// in the case the write buffer is filling up due to write errors.
182 ///
183 /// The default value is unlimited.
184 ///
185 /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
186 /// when writes to the underlying stream are failing. So the **write buffer can not
187 /// fill up if you are not observing write errors even if not flushing**.
188 ///
189 /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
190 /// and probably a little more depending on error handling strategy.
max_write_buffer_size(mut self, max: usize) -> Self191 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
192 self.config.max_write_buffer_size = max;
193 self
194 }
195
196 /// Set the maximum message size (defaults to 64 megabytes)
max_message_size(mut self, max: usize) -> Self197 pub fn max_message_size(mut self, max: usize) -> Self {
198 self.config.max_message_size = Some(max);
199 self
200 }
201
202 /// Set the maximum frame size (defaults to 16 megabytes)
max_frame_size(mut self, max: usize) -> Self203 pub fn max_frame_size(mut self, max: usize) -> Self {
204 self.config.max_frame_size = Some(max);
205 self
206 }
207
208 /// Allow server to accept unmasked frames (defaults to false)
accept_unmasked_frames(mut self, accept: bool) -> Self209 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
210 self.config.accept_unmasked_frames = accept;
211 self
212 }
213
214 /// Set the known protocols.
215 ///
216 /// If the protocol name specified by `Sec-WebSocket-Protocol` header
217 /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
218 /// return the protocol name.
219 ///
220 /// The protocols should be listed in decreasing order of preference: if the client offers
221 /// multiple protocols that the server could support, the server will pick the first one in
222 /// this list.
223 ///
224 /// # Examples
225 ///
226 /// ```
227 /// use axum::{
228 /// extract::ws::{WebSocketUpgrade, WebSocket},
229 /// routing::get,
230 /// response::{IntoResponse, Response},
231 /// Router,
232 /// };
233 ///
234 /// let app = Router::new().route("/ws", get(handler));
235 ///
236 /// async fn handler(ws: WebSocketUpgrade) -> Response {
237 /// ws.protocols(["graphql-ws", "graphql-transport-ws"])
238 /// .on_upgrade(|socket| async {
239 /// // ...
240 /// })
241 /// }
242 /// # async {
243 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
244 /// # };
245 /// ```
protocols<I>(mut self, protocols: I) -> Self where I: IntoIterator, I::Item: Into<Cow<'static, str>>,246 pub fn protocols<I>(mut self, protocols: I) -> Self
247 where
248 I: IntoIterator,
249 I::Item: Into<Cow<'static, str>>,
250 {
251 if let Some(req_protocols) = self
252 .sec_websocket_protocol
253 .as_ref()
254 .and_then(|p| p.to_str().ok())
255 {
256 self.protocol = protocols
257 .into_iter()
258 // FIXME: This will often allocate a new `String` and so is less efficient than it
259 // could be. But that can't be fixed without breaking changes to the public API.
260 .map(Into::into)
261 .find(|protocol| {
262 req_protocols
263 .split(',')
264 .any(|req_protocol| req_protocol.trim() == protocol)
265 })
266 .map(|protocol| match protocol {
267 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
268 Cow::Borrowed(s) => HeaderValue::from_static(s),
269 });
270 }
271
272 self
273 }
274
275 /// Provide a callback to call if upgrading the connection fails.
276 ///
277 /// The connection upgrade is performed in a background task. If that fails this callback
278 /// will be called.
279 ///
280 /// By default any errors will be silently ignored.
281 ///
282 /// # Example
283 ///
284 /// ```
285 /// use axum::{
286 /// extract::{WebSocketUpgrade},
287 /// response::Response,
288 /// };
289 ///
290 /// async fn handler(ws: WebSocketUpgrade) -> Response {
291 /// ws.on_failed_upgrade(|error| {
292 /// report_error(error);
293 /// })
294 /// .on_upgrade(|socket| async { /* ... */ })
295 /// }
296 /// #
297 /// # fn report_error(_: axum::Error) {}
298 /// ```
on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C> where C: OnFailedUpdgrade,299 pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
300 where
301 C: OnFailedUpdgrade,
302 {
303 WebSocketUpgrade {
304 config: self.config,
305 protocol: self.protocol,
306 sec_websocket_key: self.sec_websocket_key,
307 on_upgrade: self.on_upgrade,
308 on_failed_upgrade: callback,
309 sec_websocket_protocol: self.sec_websocket_protocol,
310 }
311 }
312
313 /// Finalize upgrading the connection and call the provided callback with
314 /// the stream.
315 #[must_use = "to setup the WebSocket connection, this response must be returned"]
on_upgrade<C, Fut>(self, callback: C) -> Response where C: FnOnce(WebSocket) -> Fut + Send + 'static, Fut: Future<Output = ()> + Send + 'static, F: OnFailedUpdgrade,316 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
317 where
318 C: FnOnce(WebSocket) -> Fut + Send + 'static,
319 Fut: Future<Output = ()> + Send + 'static,
320 F: OnFailedUpdgrade,
321 {
322 let on_upgrade = self.on_upgrade;
323 let config = self.config;
324 let on_failed_upgrade = self.on_failed_upgrade;
325
326 let protocol = self.protocol.clone();
327
328 tokio::spawn(async move {
329 let upgraded = match on_upgrade.await {
330 Ok(upgraded) => upgraded,
331 Err(err) => {
332 on_failed_upgrade.call(Error::new(err));
333 return;
334 }
335 };
336
337 let socket =
338 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
339 .await;
340 let socket = WebSocket {
341 inner: socket,
342 protocol,
343 };
344 callback(socket).await;
345 });
346
347 #[allow(clippy::declare_interior_mutable_const)]
348 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
349 #[allow(clippy::declare_interior_mutable_const)]
350 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
351
352 let mut builder = Response::builder()
353 .status(StatusCode::SWITCHING_PROTOCOLS)
354 .header(header::CONNECTION, UPGRADE)
355 .header(header::UPGRADE, WEBSOCKET)
356 .header(
357 header::SEC_WEBSOCKET_ACCEPT,
358 sign(self.sec_websocket_key.as_bytes()),
359 );
360
361 if let Some(protocol) = self.protocol {
362 builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
363 }
364
365 builder.body(body::boxed(body::Empty::new())).unwrap()
366 }
367 }
368
/// What to do when a connection upgrade fails.
///
/// Implemented for every `FnOnce(Error) + Send + 'static` callable and for
/// [`DefaultOnFailedUpdgrade`].
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
pub trait OnFailedUpdgrade: Send + 'static {
    /// Call the callback, consuming it and handing over the error that made the
    /// upgrade fail.
    fn call(self, error: Error);
}
376
377 impl<F> OnFailedUpdgrade for F
378 where
379 F: FnOnce(Error) + Send + 'static,
380 {
call(self, error: Error)381 fn call(self, error: Error) {
382 self(error)
383 }
384 }
385
/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
// `#[non_exhaustive]` keeps code outside this crate from constructing the value directly.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;
392
393 impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
394 #[inline]
call(self, _error: Error)395 fn call(self, _error: Error) {}
396 }
397
398 #[async_trait]
399 impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpdgrade>
400 where
401 S: Send + Sync,
402 {
403 type Rejection = WebSocketUpgradeRejection;
404
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>405 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
406 if parts.method != Method::GET {
407 return Err(MethodNotGet.into());
408 }
409
410 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
411 return Err(InvalidConnectionHeader.into());
412 }
413
414 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
415 return Err(InvalidUpgradeHeader.into());
416 }
417
418 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
419 return Err(InvalidWebSocketVersionHeader.into());
420 }
421
422 let sec_websocket_key = parts
423 .headers
424 .get(header::SEC_WEBSOCKET_KEY)
425 .ok_or(WebSocketKeyHeaderMissing)?
426 .clone();
427
428 let on_upgrade = parts
429 .extensions
430 .remove::<OnUpgrade>()
431 .ok_or(ConnectionNotUpgradable)?;
432
433 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
434
435 Ok(Self {
436 config: Default::default(),
437 protocol: None,
438 sec_websocket_key,
439 on_upgrade,
440 sec_websocket_protocol,
441 on_failed_upgrade: DefaultOnFailedUpdgrade,
442 })
443 }
444 }
445
header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool446 fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
447 if let Some(header) = headers.get(&key) {
448 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
449 } else {
450 false
451 }
452 }
453
header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool454 fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
455 let header = if let Some(header) = headers.get(&key) {
456 header
457 } else {
458 return false;
459 };
460
461 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
462 header.to_ascii_lowercase().contains(value)
463 } else {
464 false
465 }
466 }
467
/// A stream of WebSocket messages.
///
/// Implements both [`Stream`] (incoming messages) and [`Sink`] (outgoing messages).
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
    /// The WebSocket stream running on top of the upgraded connection.
    inner: WebSocketStream<Upgraded>,
    /// The subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
476
477 impl WebSocket {
478 /// Receive another message.
479 ///
480 /// Returns `None` if the stream has closed.
recv(&mut self) -> Option<Result<Message, Error>>481 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
482 self.next().await
483 }
484
485 /// Send a message.
send(&mut self, msg: Message) -> Result<(), Error>486 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
487 self.inner
488 .send(msg.into_tungstenite())
489 .await
490 .map_err(Error::new)
491 }
492
493 /// Gracefully close this WebSocket.
close(mut self) -> Result<(), Error>494 pub async fn close(mut self) -> Result<(), Error> {
495 self.inner.close(None).await.map_err(Error::new)
496 }
497
498 /// Return the selected WebSocket subprotocol, if one has been chosen.
protocol(&self) -> Option<&HeaderValue>499 pub fn protocol(&self) -> Option<&HeaderValue> {
500 self.protocol.as_ref()
501 }
502 }
503
504 impl Stream for WebSocket {
505 type Item = Result<Message, Error>;
506
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>507 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
508 loop {
509 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
510 Some(Ok(msg)) => {
511 if let Some(msg) = Message::from_tungstenite(msg) {
512 return Poll::Ready(Some(Ok(msg)));
513 }
514 }
515 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
516 None => return Poll::Ready(None),
517 }
518 }
519 }
520 }
521
522 impl Sink<Message> for WebSocket {
523 type Error = Error;
524
poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>525 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
526 Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
527 }
528
start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error>529 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
530 Pin::new(&mut self.inner)
531 .start_send(item.into_tungstenite())
532 .map_err(Error::new)
533 }
534
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>535 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
536 Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
537 }
538
poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>539 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
540 Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
541 }
542 }
543
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Predefined values are available as constants in the [`close_code`] module.
pub type CloseCode = u16;
546
/// A struct representing the close command.
///
/// The reason text is a [`Cow`], so it may either borrow a string for `'t` or own it.
/// [`Message::Close`] stores the `'static` form.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Cow<'t, str>,
}
555
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text WebSocket message
    Text(String),
    /// A binary WebSocket message
    Binary(Vec<u8>),
    /// A ping message with the specified payload
    ///
    /// The payload here must have a length of at most 125 bytes.
    ///
    /// Ping messages will be automatically responded to by the server, so you do not have to worry
    /// about dealing with them yourself.
    Ping(Vec<u8>),
    /// A pong message with the specified payload
    ///
    /// The payload here must have a length of at most 125 bytes.
    ///
    /// Pong messages will be automatically sent to the client if a ping message is received, so
    /// you do not have to worry about constructing them yourself unless you want to implement a
    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
    Pong(Vec<u8>),
    /// A close message with the optional close frame.
    Close(Option<CloseFrame<'static>>),
}
603
604 impl Message {
into_tungstenite(self) -> ts::Message605 fn into_tungstenite(self) -> ts::Message {
606 match self {
607 Self::Text(text) => ts::Message::Text(text),
608 Self::Binary(binary) => ts::Message::Binary(binary),
609 Self::Ping(ping) => ts::Message::Ping(ping),
610 Self::Pong(pong) => ts::Message::Pong(pong),
611 Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
612 code: ts::protocol::frame::coding::CloseCode::from(close.code),
613 reason: close.reason,
614 })),
615 Self::Close(None) => ts::Message::Close(None),
616 }
617 }
618
from_tungstenite(message: ts::Message) -> Option<Self>619 fn from_tungstenite(message: ts::Message) -> Option<Self> {
620 match message {
621 ts::Message::Text(text) => Some(Self::Text(text)),
622 ts::Message::Binary(binary) => Some(Self::Binary(binary)),
623 ts::Message::Ping(ping) => Some(Self::Ping(ping)),
624 ts::Message::Pong(pong) => Some(Self::Pong(pong)),
625 ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
626 code: close.code.into(),
627 reason: close.reason,
628 }))),
629 ts::Message::Close(None) => Some(Self::Close(None)),
630 // we can ignore `Frame` frames as recommended by the tungstenite maintainers
631 // https://github.com/snapview/tungstenite-rs/issues/268
632 ts::Message::Frame(_) => None,
633 }
634 }
635
636 /// Consume the WebSocket and return it as binary data.
into_data(self) -> Vec<u8>637 pub fn into_data(self) -> Vec<u8> {
638 match self {
639 Self::Text(string) => string.into_bytes(),
640 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
641 Self::Close(None) => Vec::new(),
642 Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
643 }
644 }
645
646 /// Attempt to consume the WebSocket message and convert it to a String.
into_text(self) -> Result<String, Error>647 pub fn into_text(self) -> Result<String, Error> {
648 match self {
649 Self::Text(string) => Ok(string),
650 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
651 .map_err(|err| err.utf8_error())
652 .map_err(Error::new)?),
653 Self::Close(None) => Ok(String::new()),
654 Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
655 }
656 }
657
658 /// Attempt to get a &str from the WebSocket message,
659 /// this will try to convert binary data to utf8.
to_text(&self) -> Result<&str, Error>660 pub fn to_text(&self) -> Result<&str, Error> {
661 match *self {
662 Self::Text(ref string) => Ok(string),
663 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
664 Ok(std::str::from_utf8(data).map_err(Error::new)?)
665 }
666 Self::Close(None) => Ok(""),
667 Self::Close(Some(ref frame)) => Ok(&frame.reason),
668 }
669 }
670 }
671
672 impl From<String> for Message {
from(string: String) -> Self673 fn from(string: String) -> Self {
674 Message::Text(string)
675 }
676 }
677
678 impl<'s> From<&'s str> for Message {
from(string: &'s str) -> Self679 fn from(string: &'s str) -> Self {
680 Message::Text(string.into())
681 }
682 }
683
684 impl<'b> From<&'b [u8]> for Message {
from(data: &'b [u8]) -> Self685 fn from(data: &'b [u8]) -> Self {
686 Message::Binary(data.into())
687 }
688 }
689
690 impl From<Vec<u8>> for Message {
from(data: Vec<u8>) -> Self691 fn from(data: Vec<u8>) -> Self {
692 Message::Binary(data)
693 }
694 }
695
696 impl From<Message> for Vec<u8> {
from(msg: Message) -> Self697 fn from(msg: Message) -> Self {
698 msg.into_data()
699 }
700 }
701
sign(key: &[u8]) -> HeaderValue702 fn sign(key: &[u8]) -> HeaderValue {
703 use base64::engine::Engine as _;
704
705 let mut sha1 = Sha1::default();
706 sha1.update(key);
707 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
708 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
709 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
710 }
711
pub mod rejection {
    //! WebSocket specific rejections.
    //!
    //! Each type corresponds to one check performed by the
    //! [`WebSocketUpgrade`](super::WebSocketUpgrade) extractor.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request method is not `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Connection` header is missing or does not include `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Upgrade` header is missing or is not `websocket`
        /// (compared ASCII case-insensitively).
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Sec-WebSocket-Version` header is missing or is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
782
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! Every constant is a plain `u16`, the type [`CloseCode`] is an alias for, so the
    //! values can be used directly as the `code` of a [`CloseFrame`](super::CloseFrame).
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
    /// it receives a binary message).
    pub const UNSUPPORTED: u16 = 1003;

    /// Indicates that no status code was included in a closing frame.
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure.
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message (e.g., non-UTF-8
    /// RFC3629 data within a text message).
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy. This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because it has expected
    /// the server to negotiate one or more extensions, but the server didn't return them in the
    /// response message of the WebSocket handshake. The list of extensions that are needed should
    /// be given as the reason for closing. Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting.
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    pub const AGAIN: u16 = 1013;
}
844
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{Request, Version};
    use tower::ServiceExt;

    /// An otherwise valid handshake sent as HTTP/1.0 must be rejected with
    /// `ConnectionNotUpgradable`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        // The extractor is taken as a `Result` so the handler itself can inspect the
        // rejection instead of the router turning it into an error response.
        let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        // All handshake headers are present; only the HTTP version is unsuitable.
        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // The handler returns `()`, so a successful run of its assertions yields `200 OK`.
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Compile-time check only (never called): a handler using the default failure
    /// callback can be registered on a `Router`.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    /// Compile-time check only (never called): a closure taking `Error` is accepted by
    /// `on_failed_upgrade` and the resulting handler can be registered on a `Router`.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }
}
895