1 //! Route to services and handlers based on HTTP methods.
2
3 use super::{future::InfallibleRouteFuture, IntoMakeService};
4 #[cfg(feature = "tokio")]
5 use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6 use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, Request, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14 };
15 use axum_core::response::IntoResponse;
16 use bytes::BytesMut;
17 use std::{
18 convert::Infallible,
19 fmt,
20 task::{Context, Poll},
21 };
22 use tower::{service_fn, util::MapResponseLayer};
23 use tower_layer::Layer;
24 use tower_service::Service;
25
// Generates the free functions (`get_service`, `post_service`, ...) that build a
// `MethodRouter` routing a single HTTP method to a `Service`.
macro_rules! top_level_service_fn {
    // `GET` carries the fully worked example; every other method links to it.
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::get_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that `get` routes are also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add explicit `HEAD`
            /// routes afterwards.
            $name,
            GET
        );
    };

    // Any other method: synthesize a one-line summary and point at `get_service`.
    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $name,
            $method
        );
    };

    // Worker arm: emits the function itself with the attributes collected above.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        pub fn $name<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error>
        where
            T: Service<Request<B>> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            B: HttpBody + Send + 'static,
            S: Clone,
        {
            MethodRouter::new().on_service(MethodFilter::$method, svc)
        }
    };
}
93
// Generates the free functions (`get`, `post`, ...) that build a `MethodRouter`
// routing a single HTTP method to a `Handler`.
macro_rules! top_level_handler_fn {
    // `GET` carries the fully worked example; every other method links to it.
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that `get` routes are also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add explicit `HEAD`
            /// routes afterwards.
            $name,
            GET
        );
    };

    // Any other method: synthesize a one-line summary and point at `get`.
    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $name,
            $method
        );
    };

    // Worker arm: emits the function itself with the attributes collected above.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        pub fn $name<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
        where
            H: Handler<T, S, B>,
            B: HttpBody + Send + 'static,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            MethodRouter::new().on(MethodFilter::$method, handler)
        }
    };
}
154
// Generates the chaining methods (`.get_service(..)`, `.post_service(..)`, ...)
// that add a `Service` for one more HTTP method to an existing `MethodRouter`.
macro_rules! chained_service_fn {
    // `GET` carries the fully worked example; every other method links to it.
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::post_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that `get` routes are also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add explicit `HEAD`
            /// routes afterwards.
            $name,
            GET
        );
    };

    // Any other method: synthesize a one-line summary and point at `get_service`.
    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $name,
            $method
        );
    };

    // Worker arm: emits the method itself with the attributes collected above.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request<B>, Error = E>
                + Clone
                + Send
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            let method_filter = MethodFilter::$method;
            self.on_service(method_filter, svc)
        }
    };
}
229
// Generates the chaining methods (`.get(..)`, `.post(..)`, ...) that add a
// `Handler` for one more HTTP method to an existing `MethodRouter`.
macro_rules! chained_handler_fn {
    // `GET` carries the fully worked example; every other method links to it.
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that `get` routes are also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add explicit `HEAD`
            /// routes afterwards.
            $name,
            GET
        );
    };

    // Any other method: synthesize a one-line summary and point at `get`.
    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $name,
            $method
        );
    };

    // Worker arm: emits the method itself with the attributes collected above.
    (
        $(#[$attr:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S, B>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            let method_filter = MethodFilter::$method;
            self.on(method_filter, handler)
        }
    };
}
290
291 top_level_service_fn!(delete_service, DELETE);
292 top_level_service_fn!(get_service, GET);
293 top_level_service_fn!(head_service, HEAD);
294 top_level_service_fn!(options_service, OPTIONS);
295 top_level_service_fn!(patch_service, PATCH);
296 top_level_service_fn!(post_service, POST);
297 top_level_service_fn!(put_service, PUT);
298 top_level_service_fn!(trace_service, TRACE);
299
300 /// Route requests with the given method to the service.
301 ///
302 /// # Example
303 ///
304 /// ```rust
305 /// use axum::{
306 /// http::Request,
307 /// routing::on,
308 /// Router,
309 /// routing::{MethodFilter, on_service},
310 /// };
311 /// use http::Response;
312 /// use std::convert::Infallible;
313 /// use hyper::Body;
314 ///
315 /// let service = tower::service_fn(|request: Request<Body>| async {
316 /// Ok::<_, Infallible>(Response::new(Body::empty()))
317 /// });
318 ///
319 /// // Requests to `POST /` will go to `service`.
320 /// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
321 /// # async {
322 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
323 /// # };
324 /// ```
on_service<T, S, B>(filter: MethodFilter, svc: T) -> MethodRouter<S, B, T::Error> where T: Service<Request<B>> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone,325 pub fn on_service<T, S, B>(filter: MethodFilter, svc: T) -> MethodRouter<S, B, T::Error>
326 where
327 T: Service<Request<B>> + Clone + Send + 'static,
328 T::Response: IntoResponse + 'static,
329 T::Future: Send + 'static,
330 B: HttpBody + Send + 'static,
331 S: Clone,
332 {
333 MethodRouter::new().on_service(filter, svc)
334 }
335
336 /// Route requests to the given service regardless of its method.
337 ///
338 /// # Example
339 ///
340 /// ```rust
341 /// use axum::{
342 /// http::Request,
343 /// Router,
344 /// routing::any_service,
345 /// };
346 /// use http::Response;
347 /// use std::convert::Infallible;
348 /// use hyper::Body;
349 ///
350 /// let service = tower::service_fn(|request: Request<Body>| async {
351 /// Ok::<_, Infallible>(Response::new(Body::empty()))
352 /// });
353 ///
354 /// // All requests to `/` will go to `service`.
355 /// let app = Router::new().route("/", any_service(service));
356 /// # async {
357 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
358 /// # };
359 /// ```
360 ///
361 /// Additional methods can still be chained:
362 ///
363 /// ```rust
364 /// use axum::{
365 /// http::Request,
366 /// Router,
367 /// routing::any_service,
368 /// };
369 /// use http::Response;
370 /// use std::convert::Infallible;
371 /// use hyper::Body;
372 ///
373 /// let service = tower::service_fn(|request: Request<Body>| async {
374 /// # Ok::<_, Infallible>(Response::new(Body::empty()))
375 /// // ...
376 /// });
377 ///
378 /// let other_service = tower::service_fn(|request: Request<Body>| async {
379 /// # Ok::<_, Infallible>(Response::new(Body::empty()))
380 /// // ...
381 /// });
382 ///
383 /// // `POST /` goes to `other_service`. All other requests go to `service`
384 /// let app = Router::new().route("/", any_service(service).post_service(other_service));
385 /// # async {
386 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
387 /// # };
388 /// ```
any_service<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error> where T: Service<Request<B>> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone,389 pub fn any_service<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error>
390 where
391 T: Service<Request<B>> + Clone + Send + 'static,
392 T::Response: IntoResponse + 'static,
393 T::Future: Send + 'static,
394 B: HttpBody + Send + 'static,
395 S: Clone,
396 {
397 MethodRouter::new()
398 .fallback_service(svc)
399 .skip_allow_header()
400 }
401
402 top_level_handler_fn!(delete, DELETE);
403 top_level_handler_fn!(get, GET);
404 top_level_handler_fn!(head, HEAD);
405 top_level_handler_fn!(options, OPTIONS);
406 top_level_handler_fn!(patch, PATCH);
407 top_level_handler_fn!(post, POST);
408 top_level_handler_fn!(put, PUT);
409 top_level_handler_fn!(trace, TRACE);
410
411 /// Route requests with the given method to the handler.
412 ///
413 /// # Example
414 ///
415 /// ```rust
416 /// use axum::{
417 /// routing::on,
418 /// Router,
419 /// routing::MethodFilter,
420 /// };
421 ///
422 /// async fn handler() {}
423 ///
424 /// // Requests to `POST /` will go to `handler`.
425 /// let app = Router::new().route("/", on(MethodFilter::POST, handler));
426 /// # async {
427 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
428 /// # };
429 /// ```
on<H, T, S, B>(filter: MethodFilter, handler: H) -> MethodRouter<S, B, Infallible> where H: Handler<T, S, B>, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static,430 pub fn on<H, T, S, B>(filter: MethodFilter, handler: H) -> MethodRouter<S, B, Infallible>
431 where
432 H: Handler<T, S, B>,
433 B: HttpBody + Send + 'static,
434 T: 'static,
435 S: Clone + Send + Sync + 'static,
436 {
437 MethodRouter::new().on(filter, handler)
438 }
439
440 /// Route requests with the given handler regardless of the method.
441 ///
442 /// # Example
443 ///
444 /// ```rust
445 /// use axum::{
446 /// routing::any,
447 /// Router,
448 /// };
449 ///
450 /// async fn handler() {}
451 ///
452 /// // All requests to `/` will go to `handler`.
453 /// let app = Router::new().route("/", any(handler));
454 /// # async {
455 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
456 /// # };
457 /// ```
458 ///
459 /// Additional methods can still be chained:
460 ///
461 /// ```rust
462 /// use axum::{
463 /// routing::any,
464 /// Router,
465 /// };
466 ///
467 /// async fn handler() {}
468 ///
469 /// async fn other_handler() {}
470 ///
471 /// // `POST /` goes to `other_handler`. All other requests go to `handler`
472 /// let app = Router::new().route("/", any(handler).post(other_handler));
473 /// # async {
474 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
475 /// # };
476 /// ```
any<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible> where H: Handler<T, S, B>, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static,477 pub fn any<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
478 where
479 H: Handler<T, S, B>,
480 B: HttpBody + Send + 'static,
481 T: 'static,
482 S: Clone + Send + Sync + 'static,
483 {
484 MethodRouter::new().fallback(handler).skip_allow_header()
485 }
486
487 /// A [`Service`] that accepts requests based on a [`MethodFilter`] and
488 /// allows chaining additional handlers and services.
489 ///
490 /// # When does `MethodRouter` implement [`Service`]?
491 ///
492 /// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
493 ///
494 /// ```
495 /// use tower::Service;
496 /// use axum::{routing::get, extract::State, body::Body, http::Request};
497 ///
498 /// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
499 /// let method_router = get(|| async {});
500 /// // and thus it implements `Service`
501 /// assert_service(method_router);
502 ///
503 /// // this requires a `String` and doesn't implement `Service`
504 /// let method_router = get(|_: State<String>| async {});
505 /// // until you provide the `String` with `.with_state(...)`
506 /// let method_router_with_state = method_router.with_state(String::new());
507 /// // and then it implements `Service`
508 /// assert_service(method_router_with_state);
509 ///
510 /// // helper to check that a value implements `Service`
511 /// fn assert_service<S>(service: S)
512 /// where
513 /// S: Service<Request<Body>>,
514 /// {}
515 /// ```
516 #[must_use]
517 pub struct MethodRouter<S = (), B = Body, E = Infallible> {
518 get: MethodEndpoint<S, B, E>,
519 head: MethodEndpoint<S, B, E>,
520 delete: MethodEndpoint<S, B, E>,
521 options: MethodEndpoint<S, B, E>,
522 patch: MethodEndpoint<S, B, E>,
523 post: MethodEndpoint<S, B, E>,
524 put: MethodEndpoint<S, B, E>,
525 trace: MethodEndpoint<S, B, E>,
526 fallback: Fallback<S, B, E>,
527 allow_header: AllowHeader,
528 }
529
530 #[derive(Clone, Debug)]
531 enum AllowHeader {
532 /// No `Allow` header value has been built-up yet. This is the default state
533 None,
534 /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
535 Skip,
536 /// The current value of the `Allow` header.
537 Bytes(BytesMut),
538 }
539
540 impl AllowHeader {
merge(self, other: Self) -> Self541 fn merge(self, other: Self) -> Self {
542 match (self, other) {
543 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
544 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
545 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
546 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
547 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
548 a.extend_from_slice(b",");
549 a.extend_from_slice(&b);
550 AllowHeader::Bytes(a)
551 }
552 }
553 }
554 }
555
556 impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result557 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
558 f.debug_struct("MethodRouter")
559 .field("get", &self.get)
560 .field("head", &self.head)
561 .field("delete", &self.delete)
562 .field("options", &self.options)
563 .field("patch", &self.patch)
564 .field("post", &self.post)
565 .field("put", &self.put)
566 .field("trace", &self.trace)
567 .field("fallback", &self.fallback)
568 .field("allow_header", &self.allow_header)
569 .finish()
570 }
571 }
572
573 impl<S, B> MethodRouter<S, B, Infallible>
574 where
575 B: HttpBody + Send + 'static,
576 S: Clone,
577 {
578 /// Chain an additional handler that will accept requests matching the given
579 /// `MethodFilter`.
580 ///
581 /// # Example
582 ///
583 /// ```rust
584 /// use axum::{
585 /// routing::get,
586 /// Router,
587 /// routing::MethodFilter
588 /// };
589 ///
590 /// async fn handler() {}
591 ///
592 /// async fn other_handler() {}
593 ///
594 /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
595 /// // `other_handler`
596 /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
597 /// # async {
598 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
599 /// # };
600 /// ```
601 #[track_caller]
on<H, T>(self, filter: MethodFilter, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, S: Send + Sync + 'static,602 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
603 where
604 H: Handler<T, S, B>,
605 T: 'static,
606 S: Send + Sync + 'static,
607 {
608 self.on_endpoint(
609 filter,
610 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
611 )
612 }
613
614 chained_handler_fn!(delete, DELETE);
615 chained_handler_fn!(get, GET);
616 chained_handler_fn!(head, HEAD);
617 chained_handler_fn!(options, OPTIONS);
618 chained_handler_fn!(patch, PATCH);
619 chained_handler_fn!(post, POST);
620 chained_handler_fn!(put, PUT);
621 chained_handler_fn!(trace, TRACE);
622
623 /// Add a fallback [`Handler`] to the router.
fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, S: Send + Sync + 'static,624 pub fn fallback<H, T>(mut self, handler: H) -> Self
625 where
626 H: Handler<T, S, B>,
627 T: 'static,
628 S: Send + Sync + 'static,
629 {
630 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
631 self
632 }
633 }
634
635 impl<B> MethodRouter<(), B, Infallible>
636 where
637 B: HttpBody + Send + 'static,
638 {
639 /// Convert the handler into a [`MakeService`].
640 ///
641 /// This allows you to serve a single handler if you don't need any routing:
642 ///
643 /// ```rust
644 /// use axum::{
645 /// Server,
646 /// handler::Handler,
647 /// http::{Uri, Method},
648 /// response::IntoResponse,
649 /// routing::get,
650 /// };
651 /// use std::net::SocketAddr;
652 ///
653 /// async fn handler(method: Method, uri: Uri, body: String) -> String {
654 /// format!("received `{} {}` with body `{:?}`", method, uri, body)
655 /// }
656 ///
657 /// let router = get(handler).post(handler);
658 ///
659 /// # async {
660 /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
661 /// .serve(router.into_make_service())
662 /// .await?;
663 /// # Ok::<_, hyper::Error>(())
664 /// # };
665 /// ```
666 ///
667 /// [`MakeService`]: tower::make::MakeService
into_make_service(self) -> IntoMakeService<Self>668 pub fn into_make_service(self) -> IntoMakeService<Self> {
669 IntoMakeService::new(self.with_state(()))
670 }
671
672 /// Convert the router into a [`MakeService`] which stores information
673 /// about the incoming connection.
674 ///
675 /// See [`Router::into_make_service_with_connect_info`] for more details.
676 ///
677 /// ```rust
678 /// use axum::{
679 /// Server,
680 /// handler::Handler,
681 /// response::IntoResponse,
682 /// extract::ConnectInfo,
683 /// routing::get,
684 /// };
685 /// use std::net::SocketAddr;
686 ///
687 /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
688 /// format!("Hello {}", addr)
689 /// }
690 ///
691 /// let router = get(handler).post(handler);
692 ///
693 /// # async {
694 /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
695 /// .serve(router.into_make_service_with_connect_info::<SocketAddr>())
696 /// .await?;
697 /// # Ok::<_, hyper::Error>(())
698 /// # };
699 /// ```
700 ///
701 /// [`MakeService`]: tower::make::MakeService
702 /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
703 #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>704 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
705 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
706 }
707 }
708
709 impl<S, B, E> MethodRouter<S, B, E>
710 where
711 B: HttpBody + Send + 'static,
712 S: Clone,
713 {
714 /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
715 /// requests.
new() -> Self716 pub fn new() -> Self {
717 let fallback = Route::new(service_fn(|_: Request<B>| async {
718 Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
719 }));
720
721 Self {
722 get: MethodEndpoint::None,
723 head: MethodEndpoint::None,
724 delete: MethodEndpoint::None,
725 options: MethodEndpoint::None,
726 patch: MethodEndpoint::None,
727 post: MethodEndpoint::None,
728 put: MethodEndpoint::None,
729 trace: MethodEndpoint::None,
730 allow_header: AllowHeader::None,
731 fallback: Fallback::Default(fallback),
732 }
733 }
734
735 /// Provide the state for the router.
with_state<S2>(self, state: S) -> MethodRouter<S2, B, E>736 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, B, E> {
737 MethodRouter {
738 get: self.get.with_state(&state),
739 head: self.head.with_state(&state),
740 delete: self.delete.with_state(&state),
741 options: self.options.with_state(&state),
742 patch: self.patch.with_state(&state),
743 post: self.post.with_state(&state),
744 put: self.put.with_state(&state),
745 trace: self.trace.with_state(&state),
746 allow_header: self.allow_header,
747 fallback: self.fallback.with_state(state),
748 }
749 }
750
751 /// Chain an additional service that will accept requests matching the given
752 /// `MethodFilter`.
753 ///
754 /// # Example
755 ///
756 /// ```rust
757 /// use axum::{
758 /// http::Request,
759 /// Router,
760 /// routing::{MethodFilter, on_service},
761 /// };
762 /// use http::Response;
763 /// use std::convert::Infallible;
764 /// use hyper::Body;
765 ///
766 /// let service = tower::service_fn(|request: Request<Body>| async {
767 /// Ok::<_, Infallible>(Response::new(Body::empty()))
768 /// });
769 ///
770 /// // Requests to `DELETE /` will go to `service`
771 /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
772 /// # async {
773 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
774 /// # };
775 /// ```
776 #[track_caller]
on_service<T>(self, filter: MethodFilter, svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,777 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
778 where
779 T: Service<Request<B>, Error = E> + Clone + Send + 'static,
780 T::Response: IntoResponse + 'static,
781 T::Future: Send + 'static,
782 {
783 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
784 }
785
786 #[track_caller]
on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self787 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self {
788 // written as a separate function to generate less IR
789 #[track_caller]
790 fn set_endpoint<S, B, E>(
791 method_name: &str,
792 out: &mut MethodEndpoint<S, B, E>,
793 endpoint: &MethodEndpoint<S, B, E>,
794 endpoint_filter: MethodFilter,
795 filter: MethodFilter,
796 allow_header: &mut AllowHeader,
797 methods: &[&'static str],
798 ) where
799 MethodEndpoint<S, B, E>: Clone,
800 S: Clone,
801 {
802 if endpoint_filter.contains(filter) {
803 if out.is_some() {
804 panic!(
805 "Overlapping method route. Cannot add two method routes that both handle \
806 `{method_name}`",
807 )
808 }
809 *out = endpoint.clone();
810 for method in methods {
811 append_allow_header(allow_header, method);
812 }
813 }
814 }
815
816 set_endpoint(
817 "GET",
818 &mut self.get,
819 &endpoint,
820 filter,
821 MethodFilter::GET,
822 &mut self.allow_header,
823 &["GET", "HEAD"],
824 );
825
826 set_endpoint(
827 "HEAD",
828 &mut self.head,
829 &endpoint,
830 filter,
831 MethodFilter::HEAD,
832 &mut self.allow_header,
833 &["HEAD"],
834 );
835
836 set_endpoint(
837 "TRACE",
838 &mut self.trace,
839 &endpoint,
840 filter,
841 MethodFilter::TRACE,
842 &mut self.allow_header,
843 &["TRACE"],
844 );
845
846 set_endpoint(
847 "PUT",
848 &mut self.put,
849 &endpoint,
850 filter,
851 MethodFilter::PUT,
852 &mut self.allow_header,
853 &["PUT"],
854 );
855
856 set_endpoint(
857 "POST",
858 &mut self.post,
859 &endpoint,
860 filter,
861 MethodFilter::POST,
862 &mut self.allow_header,
863 &["POST"],
864 );
865
866 set_endpoint(
867 "PATCH",
868 &mut self.patch,
869 &endpoint,
870 filter,
871 MethodFilter::PATCH,
872 &mut self.allow_header,
873 &["PATCH"],
874 );
875
876 set_endpoint(
877 "OPTIONS",
878 &mut self.options,
879 &endpoint,
880 filter,
881 MethodFilter::OPTIONS,
882 &mut self.allow_header,
883 &["OPTIONS"],
884 );
885
886 set_endpoint(
887 "DELETE",
888 &mut self.delete,
889 &endpoint,
890 filter,
891 MethodFilter::DELETE,
892 &mut self.allow_header,
893 &["DELETE"],
894 );
895
896 self
897 }
898
899 chained_service_fn!(delete_service, DELETE);
900 chained_service_fn!(get_service, GET);
901 chained_service_fn!(head_service, HEAD);
902 chained_service_fn!(options_service, OPTIONS);
903 chained_service_fn!(patch_service, PATCH);
904 chained_service_fn!(post_service, POST);
905 chained_service_fn!(put_service, PUT);
906 chained_service_fn!(trace_service, TRACE);
907
908 #[doc = include_str!("../docs/method_routing/fallback.md")]
fallback_service<T>(mut self, svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,909 pub fn fallback_service<T>(mut self, svc: T) -> Self
910 where
911 T: Service<Request<B>, Error = E> + Clone + Send + 'static,
912 T::Response: IntoResponse + 'static,
913 T::Future: Send + 'static,
914 {
915 self.fallback = Fallback::Service(Route::new(svc));
916 self
917 }
918
919 #[doc = include_str!("../docs/method_routing/layer.md")]
layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, E: 'static, S: 'static, NewReqBody: HttpBody + 'static, NewError: 'static,920 pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
921 where
922 L: Layer<Route<B, E>> + Clone + Send + 'static,
923 L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
924 <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
925 <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
926 <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
927 E: 'static,
928 S: 'static,
929 NewReqBody: HttpBody + 'static,
930 NewError: 'static,
931 {
932 let layer_fn = move |route: Route<B, E>| route.layer(layer.clone());
933
934 MethodRouter {
935 get: self.get.map(layer_fn.clone()),
936 head: self.head.map(layer_fn.clone()),
937 delete: self.delete.map(layer_fn.clone()),
938 options: self.options.map(layer_fn.clone()),
939 patch: self.patch.map(layer_fn.clone()),
940 post: self.post.map(layer_fn.clone()),
941 put: self.put.map(layer_fn.clone()),
942 trace: self.trace.map(layer_fn.clone()),
943 fallback: self.fallback.map(layer_fn),
944 allow_header: self.allow_header,
945 }
946 }
947
948 #[doc = include_str!("../docs/method_routing/route_layer.md")]
949 #[track_caller]
route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static, E: 'static, S: 'static,950 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E>
951 where
952 L: Layer<Route<B, E>> + Clone + Send + 'static,
953 L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static,
954 <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
955 <L::Service as Service<Request<B>>>::Future: Send + 'static,
956 E: 'static,
957 S: 'static,
958 {
959 if self.get.is_none()
960 && self.head.is_none()
961 && self.delete.is_none()
962 && self.options.is_none()
963 && self.patch.is_none()
964 && self.post.is_none()
965 && self.put.is_none()
966 && self.trace.is_none()
967 {
968 panic!(
969 "Adding a route_layer before any routes is a no-op. \
970 Add the routes you want the layer to apply to first."
971 );
972 }
973
974 let layer_fn = move |svc| {
975 let svc = layer.layer(svc);
976 let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
977 Route::new(svc)
978 };
979
980 self.get = self.get.map(layer_fn.clone());
981 self.head = self.head.map(layer_fn.clone());
982 self.delete = self.delete.map(layer_fn.clone());
983 self.options = self.options.map(layer_fn.clone());
984 self.patch = self.patch.map(layer_fn.clone());
985 self.post = self.post.map(layer_fn.clone());
986 self.put = self.put.map(layer_fn.clone());
987 self.trace = self.trace.map(layer_fn);
988
989 self
990 }
991
992 #[track_caller]
merge_for_path( mut self, path: Option<&str>, other: MethodRouter<S, B, E>, ) -> Self993 pub(crate) fn merge_for_path(
994 mut self,
995 path: Option<&str>,
996 other: MethodRouter<S, B, E>,
997 ) -> Self {
998 // written using inner functions to generate less IR
999 #[track_caller]
1000 fn merge_inner<S, B, E>(
1001 path: Option<&str>,
1002 name: &str,
1003 first: MethodEndpoint<S, B, E>,
1004 second: MethodEndpoint<S, B, E>,
1005 ) -> MethodEndpoint<S, B, E> {
1006 match (first, second) {
1007 (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1008 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1009 _ => {
1010 if let Some(path) = path {
1011 panic!(
1012 "Overlapping method route. Handler for `{name} {path}` already exists"
1013 );
1014 } else {
1015 panic!(
1016 "Overlapping method route. Cannot merge two method routes that both \
1017 define `{name}`"
1018 );
1019 }
1020 }
1021 }
1022 }
1023
1024 self.get = merge_inner(path, "GET", self.get, other.get);
1025 self.head = merge_inner(path, "HEAD", self.head, other.head);
1026 self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1027 self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1028 self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1029 self.post = merge_inner(path, "POST", self.post, other.post);
1030 self.put = merge_inner(path, "PUT", self.put, other.put);
1031 self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1032
1033 self.fallback = self
1034 .fallback
1035 .merge(other.fallback)
1036 .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1037
1038 self.allow_header = self.allow_header.merge(other.allow_header);
1039
1040 self
1041 }
1042
1043 #[doc = include_str!("../docs/method_routing/merge.md")]
1044 #[track_caller]
merge(self, other: MethodRouter<S, B, E>) -> Self1045 pub fn merge(self, other: MethodRouter<S, B, E>) -> Self {
1046 self.merge_for_path(None, other)
1047 }
1048
1049 /// Apply a [`HandleErrorLayer`].
1050 ///
1051 /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible> where F: Clone + Send + Sync + 'static, HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send, T: 'static, E: 'static, B: 'static, S: 'static,1052 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible>
1053 where
1054 F: Clone + Send + Sync + 'static,
1055 HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>,
1056 <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send,
1057 <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
1058 T: 'static,
1059 E: 'static,
1060 B: 'static,
1061 S: 'static,
1062 {
1063 self.layer(HandleErrorLayer::new(f))
1064 }
1065
skip_allow_header(mut self) -> Self1066 fn skip_allow_header(mut self) -> Self {
1067 self.allow_header = AllowHeader::Skip;
1068 self
1069 }
1070
call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E>1071 pub(crate) fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
1072 macro_rules! call {
1073 (
1074 $req:expr,
1075 $method:expr,
1076 $method_variant:ident,
1077 $svc:expr
1078 ) => {
1079 if $method == Method::$method_variant {
1080 match $svc {
1081 MethodEndpoint::None => {}
1082 MethodEndpoint::Route(route) => {
1083 return RouteFuture::from_future(route.oneshot_inner($req))
1084 .strip_body($method == Method::HEAD);
1085 }
1086 MethodEndpoint::BoxedHandler(handler) => {
1087 let mut route = handler.clone().into_route(state);
1088 return RouteFuture::from_future(route.oneshot_inner($req))
1089 .strip_body($method == Method::HEAD);
1090 }
1091 }
1092 }
1093 };
1094 }
1095
1096 let method = req.method().clone();
1097
1098 // written with a pattern match like this to ensure we call all routes
1099 let Self {
1100 get,
1101 head,
1102 delete,
1103 options,
1104 patch,
1105 post,
1106 put,
1107 trace,
1108 fallback,
1109 allow_header,
1110 } = self;
1111
1112 call!(req, method, HEAD, head);
1113 call!(req, method, HEAD, get);
1114 call!(req, method, GET, get);
1115 call!(req, method, POST, post);
1116 call!(req, method, OPTIONS, options);
1117 call!(req, method, PATCH, patch);
1118 call!(req, method, PUT, put);
1119 call!(req, method, DELETE, delete);
1120 call!(req, method, TRACE, trace);
1121
1122 let future = fallback.call_with_state(req, state);
1123
1124 match allow_header {
1125 AllowHeader::None => future.allow_header(Bytes::new()),
1126 AllowHeader::Skip => future,
1127 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1128 }
1129 }
1130 }
1131
append_allow_header(allow_header: &mut AllowHeader, method: &'static str)1132 fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1133 match allow_header {
1134 AllowHeader::None => {
1135 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1136 }
1137 AllowHeader::Skip => {}
1138 AllowHeader::Bytes(allow_header) => {
1139 if let Ok(s) = std::str::from_utf8(allow_header) {
1140 if !s.contains(method) {
1141 allow_header.extend_from_slice(b",");
1142 allow_header.extend_from_slice(method.as_bytes());
1143 }
1144 } else {
1145 #[cfg(debug_assertions)]
1146 panic!("`allow_header` contained invalid uft-8. This should never happen")
1147 }
1148 }
1149 }
1150 }
1151
1152 impl<S, B, E> Clone for MethodRouter<S, B, E> {
clone(&self) -> Self1153 fn clone(&self) -> Self {
1154 Self {
1155 get: self.get.clone(),
1156 head: self.head.clone(),
1157 delete: self.delete.clone(),
1158 options: self.options.clone(),
1159 patch: self.patch.clone(),
1160 post: self.post.clone(),
1161 put: self.put.clone(),
1162 trace: self.trace.clone(),
1163 fallback: self.fallback.clone(),
1164 allow_header: self.allow_header.clone(),
1165 }
1166 }
1167 }
1168
1169 impl<S, B, E> Default for MethodRouter<S, B, E>
1170 where
1171 B: HttpBody + Send + 'static,
1172 S: Clone,
1173 {
default() -> Self1174 fn default() -> Self {
1175 Self::new()
1176 }
1177 }
1178
1179 enum MethodEndpoint<S, B, E> {
1180 None,
1181 Route(Route<B, E>),
1182 BoxedHandler(BoxedIntoRoute<S, B, E>),
1183 }
1184
1185 impl<S, B, E> MethodEndpoint<S, B, E>
1186 where
1187 S: Clone,
1188 {
is_some(&self) -> bool1189 fn is_some(&self) -> bool {
1190 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1191 }
1192
is_none(&self) -> bool1193 fn is_none(&self) -> bool {
1194 matches!(self, Self::None)
1195 }
1196
map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2> where S: 'static, B: 'static, E: 'static, F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static,1197 fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
1198 where
1199 S: 'static,
1200 B: 'static,
1201 E: 'static,
1202 F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
1203 B2: HttpBody + 'static,
1204 E2: 'static,
1205 {
1206 match self {
1207 Self::None => MethodEndpoint::None,
1208 Self::Route(route) => MethodEndpoint::Route(f(route)),
1209 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1210 }
1211 }
1212
with_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E>1213 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E> {
1214 match self {
1215 MethodEndpoint::None => MethodEndpoint::None,
1216 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1217 MethodEndpoint::BoxedHandler(handler) => {
1218 MethodEndpoint::Route(handler.into_route(state.clone()))
1219 }
1220 }
1221 }
1222 }
1223
1224 impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
clone(&self) -> Self1225 fn clone(&self) -> Self {
1226 match self {
1227 Self::None => Self::None,
1228 Self::Route(inner) => Self::Route(inner.clone()),
1229 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1230 }
1231 }
1232 }
1233
1234 impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1235 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1236 match self {
1237 Self::None => f.debug_tuple("None").finish(),
1238 Self::Route(inner) => inner.fmt(f),
1239 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1240 }
1241 }
1242 }
1243
1244 impl<B, E> Service<Request<B>> for MethodRouter<(), B, E>
1245 where
1246 B: HttpBody + Send + 'static,
1247 {
1248 type Response = Response;
1249 type Error = E;
1250 type Future = RouteFuture<B, E>;
1251
1252 #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>1253 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1254 Poll::Ready(Ok(()))
1255 }
1256
1257 #[inline]
call(&mut self, req: Request<B>) -> Self::Future1258 fn call(&mut self, req: Request<B>) -> Self::Future {
1259 self.call_with_state(req, ())
1260 }
1261 }
1262
1263 impl<S, B> Handler<(), S, B> for MethodRouter<S, B>
1264 where
1265 S: Clone + 'static,
1266 B: HttpBody + Send + 'static,
1267 {
1268 type Future = InfallibleRouteFuture<B>;
1269
call(mut self, req: Request<B>, state: S) -> Self::Future1270 fn call(mut self, req: Request<B>, state: S) -> Self::Future {
1271 InfallibleRouteFuture::new(self.call_with_state(req, state))
1272 }
1273 }
1274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        body::Body, error_handling::HandleErrorLayer, extract::State,
        handler::HandlerWithoutStateExt,
    };
    use axum_core::response::IntoResponse;
    use http::{header::ALLOW, HeaderMap};
    use std::time::Duration;
    use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
    use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};

    // An empty router answers every method with `405` and no body.
    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    // A `GET` endpoint also serves `HEAD`, with the body removed.
    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    // An explicit `HEAD` endpoint wins over the implicit `GET` one.
    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    // `layer` applies to both registered endpoints and the fallback.
    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    // `route_layer` applies only to registered endpoints, not to the fallback.
    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Compile-only check that the builder methods compose; never executed.
    #[allow(dead_code)]
    fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(|_| async {
                            StatusCode::REQUEST_TIMEOUT
                        }))
                        .layer(TimeoutLayer::new(Duration::from_secs(10))),
                ),
        );

        crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // An `Allow` header set by the fallback itself must not be overwritten.
    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body too, so the assertion below checks the POST handler's
        // response rather than the body left over from the GET call.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send an empty-bodied request for `/` with the given `method` through `svc`
    /// and return the response status, headers and body decoded as UTF-8.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request<Body>, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    /// Handler that responds with `200 OK` and the body `ok`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    /// Handler that responds with `201 Created` and the body `created`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}
1588