1 use crate::{
2     body::{Bytes, Empty},
3     error_handling::HandleErrorLayer,
4     extract::{self, DefaultBodyLimit, FromRef, Path, State},
5     handler::{Handler, HandlerWithoutStateExt},
6     response::IntoResponse,
7     routing::{
8         delete, get, get_service, on, on_service, patch, patch_service,
9         path_router::path_for_nested_route, post, MethodFilter,
10     },
11     test_helpers::{
12         tracing_helpers::{capture_tracing, TracingEvent},
13         *,
14     },
15     BoxError, Extension, Json, Router,
16 };
17 use futures_util::stream::StreamExt;
18 use http::{
19     header::CONTENT_LENGTH,
20     header::{ALLOW, HOST},
21     HeaderMap, Method, Request, Response, StatusCode, Uri,
22 };
23 use hyper::Body;
24 use serde::Deserialize;
25 use serde_json::json;
26 use std::{
27     convert::Infallible,
28     future::{ready, Ready},
29     sync::atomic::{AtomicBool, AtomicUsize, Ordering},
30     task::{Context, Poll},
31     time::Duration,
32 };
33 use tower::{
34     service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
35 };
36 use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
37 use tower_service::Service;
38 
39 mod fallback;
40 mod get_to_head;
41 mod handle_error;
42 mod merge;
43 mod nest;
44 
45 #[crate::test]
hello_world()46 async fn hello_world() {
47     async fn root(_: Request<Body>) -> &'static str {
48         "Hello, World!"
49     }
50 
51     async fn foo(_: Request<Body>) -> &'static str {
52         "foo"
53     }
54 
55     async fn users_create(_: Request<Body>) -> &'static str {
56         "users#create"
57     }
58 
59     let app = Router::new()
60         .route("/", get(root).post(foo))
61         .route("/users", post(users_create));
62 
63     let client = TestClient::new(app);
64 
65     let res = client.get("/").send().await;
66     let body = res.text().await;
67     assert_eq!(body, "Hello, World!");
68 
69     let res = client.post("/").send().await;
70     let body = res.text().await;
71     assert_eq!(body, "foo");
72 
73     let res = client.post("/users").send().await;
74     let body = res.text().await;
75     assert_eq!(body, "users#create");
76 }
77 
78 #[crate::test]
routing()79 async fn routing() {
80     let app = Router::new()
81         .route(
82             "/users",
83             get(|_: Request<Body>| async { "users#index" })
84                 .post(|_: Request<Body>| async { "users#create" }),
85         )
86         .route("/users/:id", get(|_: Request<Body>| async { "users#show" }))
87         .route(
88             "/users/:id/action",
89             get(|_: Request<Body>| async { "users#action" }),
90         );
91 
92     let client = TestClient::new(app);
93 
94     let res = client.get("/").send().await;
95     assert_eq!(res.status(), StatusCode::NOT_FOUND);
96 
97     let res = client.get("/users").send().await;
98     assert_eq!(res.status(), StatusCode::OK);
99     assert_eq!(res.text().await, "users#index");
100 
101     let res = client.post("/users").send().await;
102     assert_eq!(res.status(), StatusCode::OK);
103     assert_eq!(res.text().await, "users#create");
104 
105     let res = client.get("/users/1").send().await;
106     assert_eq!(res.status(), StatusCode::OK);
107     assert_eq!(res.text().await, "users#show");
108 
109     let res = client.get("/users/1/action").send().await;
110     assert_eq!(res.status(), StatusCode::OK);
111     assert_eq!(res.text().await, "users#action");
112 }
113 
114 #[crate::test]
router_type_doesnt_change()115 async fn router_type_doesnt_change() {
116     let app: Router = Router::new()
117         .route(
118             "/",
119             on(MethodFilter::GET, |_: Request<Body>| async {
120                 "hi from GET"
121             })
122             .on(MethodFilter::POST, |_: Request<Body>| async {
123                 "hi from POST"
124             }),
125         )
126         .layer(tower_http::compression::CompressionLayer::new());
127 
128     let client = TestClient::new(app);
129 
130     let res = client.get("/").send().await;
131     assert_eq!(res.status(), StatusCode::OK);
132     assert_eq!(res.text().await, "hi from GET");
133 
134     let res = client.post("/").send().await;
135     assert_eq!(res.status(), StatusCode::OK);
136     assert_eq!(res.text().await, "hi from POST");
137 }
138 
139 #[crate::test]
routing_between_services()140 async fn routing_between_services() {
141     use std::convert::Infallible;
142     use tower::service_fn;
143 
144     async fn handle(_: Request<Body>) -> &'static str {
145         "handler"
146     }
147 
148     let app = Router::new()
149         .route(
150             "/one",
151             get_service(service_fn(|_: Request<Body>| async {
152                 Ok::<_, Infallible>(Response::new(Body::from("one get")))
153             }))
154             .post_service(service_fn(|_: Request<Body>| async {
155                 Ok::<_, Infallible>(Response::new(Body::from("one post")))
156             }))
157             .on_service(
158                 MethodFilter::PUT,
159                 service_fn(|_: Request<Body>| async {
160                     Ok::<_, Infallible>(Response::new(Body::from("one put")))
161                 }),
162             ),
163         )
164         .route("/two", on_service(MethodFilter::GET, handle.into_service()));
165 
166     let client = TestClient::new(app);
167 
168     let res = client.get("/one").send().await;
169     assert_eq!(res.status(), StatusCode::OK);
170     assert_eq!(res.text().await, "one get");
171 
172     let res = client.post("/one").send().await;
173     assert_eq!(res.status(), StatusCode::OK);
174     assert_eq!(res.text().await, "one post");
175 
176     let res = client.put("/one").send().await;
177     assert_eq!(res.status(), StatusCode::OK);
178     assert_eq!(res.text().await, "one put");
179 
180     let res = client.get("/two").send().await;
181     assert_eq!(res.status(), StatusCode::OK);
182     assert_eq!(res.text().await, "handler");
183 }
184 
185 #[crate::test]
middleware_on_single_route()186 async fn middleware_on_single_route() {
187     use tower::ServiceBuilder;
188     use tower_http::{compression::CompressionLayer, trace::TraceLayer};
189 
190     async fn handle(_: Request<Body>) -> &'static str {
191         "Hello, World!"
192     }
193 
194     let app = Router::new().route(
195         "/",
196         get(handle.layer(
197             ServiceBuilder::new()
198                 .layer(TraceLayer::new_for_http())
199                 .layer(CompressionLayer::new())
200                 .into_inner(),
201         )),
202     );
203 
204     let client = TestClient::new(app);
205 
206     let res = client.get("/").send().await;
207     let body = res.text().await;
208 
209     assert_eq!(body, "Hello, World!");
210 }
211 
212 #[crate::test]
service_in_bottom()213 async fn service_in_bottom() {
214     async fn handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
215         Ok(Response::new(hyper::Body::empty()))
216     }
217 
218     let app = Router::new().route("/", get_service(service_fn(handler)));
219 
220     TestClient::new(app);
221 }
222 
223 #[crate::test]
wrong_method_handler()224 async fn wrong_method_handler() {
225     let app = Router::new()
226         .route("/", get(|| async {}).post(|| async {}))
227         .route("/foo", patch(|| async {}));
228 
229     let client = TestClient::new(app);
230 
231     let res = client.patch("/").send().await;
232     assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
233     assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST");
234 
235     let res = client.patch("/foo").send().await;
236     assert_eq!(res.status(), StatusCode::OK);
237 
238     let res = client.post("/foo").send().await;
239     assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
240     assert_eq!(res.headers()[ALLOW], "PATCH");
241 
242     let res = client.get("/bar").send().await;
243     assert_eq!(res.status(), StatusCode::NOT_FOUND);
244 }
245 
246 #[crate::test]
wrong_method_service()247 async fn wrong_method_service() {
248     #[derive(Clone)]
249     struct Svc;
250 
251     impl<R> Service<R> for Svc {
252         type Response = Response<Empty<Bytes>>;
253         type Error = Infallible;
254         type Future = Ready<Result<Self::Response, Self::Error>>;
255 
256         fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257             Poll::Ready(Ok(()))
258         }
259 
260         fn call(&mut self, _req: R) -> Self::Future {
261             ready(Ok(Response::new(Empty::new())))
262         }
263     }
264 
265     let app = Router::new()
266         .route("/", get_service(Svc).post_service(Svc))
267         .route("/foo", patch_service(Svc));
268 
269     let client = TestClient::new(app);
270 
271     let res = client.patch("/").send().await;
272     assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
273     assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST");
274 
275     let res = client.patch("/foo").send().await;
276     assert_eq!(res.status(), StatusCode::OK);
277 
278     let res = client.post("/foo").send().await;
279     assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
280     assert_eq!(res.headers()[ALLOW], "PATCH");
281 
282     let res = client.get("/bar").send().await;
283     assert_eq!(res.status(), StatusCode::NOT_FOUND);
284 }
285 
286 #[crate::test]
multiple_methods_for_one_handler()287 async fn multiple_methods_for_one_handler() {
288     async fn root(_: Request<Body>) -> &'static str {
289         "Hello, World!"
290     }
291 
292     let app = Router::new().route("/", on(MethodFilter::GET | MethodFilter::POST, root));
293 
294     let client = TestClient::new(app);
295 
296     let res = client.get("/").send().await;
297     assert_eq!(res.status(), StatusCode::OK);
298 
299     let res = client.post("/").send().await;
300     assert_eq!(res.status(), StatusCode::OK);
301 }
302 
303 #[crate::test]
wildcard_sees_whole_url()304 async fn wildcard_sees_whole_url() {
305     let app = Router::new().route("/api/*rest", get(|uri: Uri| async move { uri.to_string() }));
306 
307     let client = TestClient::new(app);
308 
309     let res = client.get("/api/foo/bar").send().await;
310     assert_eq!(res.text().await, "/api/foo/bar");
311 }
312 
313 #[crate::test]
middleware_applies_to_routes_above()314 async fn middleware_applies_to_routes_above() {
315     let app = Router::new()
316         .route("/one", get(std::future::pending::<()>))
317         .layer(
318             ServiceBuilder::new()
319                 .layer(HandleErrorLayer::new(|_: BoxError| async move {
320                     StatusCode::REQUEST_TIMEOUT
321                 }))
322                 .layer(TimeoutLayer::new(Duration::new(0, 0))),
323         )
324         .route("/two", get(|| async {}));
325 
326     let client = TestClient::new(app);
327 
328     let res = client.get("/one").send().await;
329     assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
330 
331     let res = client.get("/two").send().await;
332     assert_eq!(res.status(), StatusCode::OK);
333 }
334 
335 #[crate::test]
not_found_for_extra_trailing_slash()336 async fn not_found_for_extra_trailing_slash() {
337     let app = Router::new().route("/foo", get(|| async {}));
338 
339     let client = TestClient::new(app);
340 
341     let res = client.get("/foo/").send().await;
342     assert_eq!(res.status(), StatusCode::NOT_FOUND);
343 
344     let res = client.get("/foo").send().await;
345     assert_eq!(res.status(), StatusCode::OK);
346 }
347 
348 #[crate::test]
not_found_for_missing_trailing_slash()349 async fn not_found_for_missing_trailing_slash() {
350     let app = Router::new().route("/foo/", get(|| async {}));
351 
352     let client = TestClient::new(app);
353 
354     let res = client.get("/foo").send().await;
355     assert_eq!(res.status(), StatusCode::NOT_FOUND);
356 }
357 
358 #[crate::test]
with_and_without_trailing_slash()359 async fn with_and_without_trailing_slash() {
360     let app = Router::new()
361         .route("/foo", get(|| async { "without tsr" }))
362         .route("/foo/", get(|| async { "with tsr" }));
363 
364     let client = TestClient::new(app);
365 
366     let res = client.get("/foo/").send().await;
367     assert_eq!(res.status(), StatusCode::OK);
368     assert_eq!(res.text().await, "with tsr");
369 
370     let res = client.get("/foo").send().await;
371     assert_eq!(res.status(), StatusCode::OK);
372     assert_eq!(res.text().await, "without tsr");
373 }
374 
375 // for https://github.com/tokio-rs/axum/issues/420
376 #[crate::test]
wildcard_doesnt_match_just_trailing_slash()377 async fn wildcard_doesnt_match_just_trailing_slash() {
378     let app = Router::new().route(
379         "/x/*path",
380         get(|Path(path): Path<String>| async move { path }),
381     );
382 
383     let client = TestClient::new(app);
384 
385     let res = client.get("/x").send().await;
386     assert_eq!(res.status(), StatusCode::NOT_FOUND);
387 
388     let res = client.get("/x/").send().await;
389     assert_eq!(res.status(), StatusCode::NOT_FOUND);
390 
391     let res = client.get("/x/foo/bar").send().await;
392     assert_eq!(res.status(), StatusCode::OK);
393     assert_eq!(res.text().await, "foo/bar");
394 }
395 
396 #[crate::test]
what_matches_wildcard()397 async fn what_matches_wildcard() {
398     let app = Router::new()
399         .route("/*key", get(|| async { "root" }))
400         .route("/x/*key", get(|| async { "x" }))
401         .fallback(|| async { "fallback" });
402 
403     let client = TestClient::new(app);
404 
405     let get = |path| {
406         let f = client.get(path).send();
407         async move { f.await.text().await }
408     };
409 
410     assert_eq!(get("/").await, "fallback");
411     assert_eq!(get("/a").await, "root");
412     assert_eq!(get("/a/").await, "root");
413     assert_eq!(get("/a/b").await, "root");
414     assert_eq!(get("/a/b/").await, "root");
415 
416     assert_eq!(get("/x").await, "root");
417     assert_eq!(get("/x/").await, "root");
418     assert_eq!(get("/x/a").await, "x");
419     assert_eq!(get("/x/a/").await, "x");
420     assert_eq!(get("/x/a/b").await, "x");
421     assert_eq!(get("/x/a/b/").await, "x");
422 }
423 
424 #[crate::test]
static_and_dynamic_paths()425 async fn static_and_dynamic_paths() {
426     let app = Router::new()
427         .route(
428             "/:key",
429             get(|Path(key): Path<String>| async move { format!("dynamic: {key}") }),
430         )
431         .route("/foo", get(|| async { "static" }));
432 
433     let client = TestClient::new(app);
434 
435     let res = client.get("/bar").send().await;
436     assert_eq!(res.text().await, "dynamic: bar");
437 
438     let res = client.get("/foo").send().await;
439     assert_eq!(res.text().await, "static");
440 }
441 
442 #[crate::test]
443 #[should_panic(expected = "Paths must start with a `/`. Use \"/\" for root routes")]
empty_route()444 async fn empty_route() {
445     let app = Router::new().route("", get(|| async {}));
446     TestClient::new(app);
447 }
448 
449 #[crate::test]
middleware_still_run_for_unmatched_requests()450 async fn middleware_still_run_for_unmatched_requests() {
451     #[derive(Clone)]
452     struct CountMiddleware<S>(S);
453 
454     static COUNT: AtomicUsize = AtomicUsize::new(0);
455 
456     impl<R, S> Service<R> for CountMiddleware<S>
457     where
458         S: Service<R>,
459     {
460         type Response = S::Response;
461         type Error = S::Error;
462         type Future = S::Future;
463 
464         fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
465             self.0.poll_ready(cx)
466         }
467 
468         fn call(&mut self, req: R) -> Self::Future {
469             COUNT.fetch_add(1, Ordering::SeqCst);
470             self.0.call(req)
471         }
472     }
473 
474     let app = Router::new()
475         .route("/", get(|| async {}))
476         .layer(tower::layer::layer_fn(CountMiddleware));
477 
478     let client = TestClient::new(app);
479 
480     assert_eq!(COUNT.load(Ordering::SeqCst), 0);
481 
482     client.get("/").send().await;
483     assert_eq!(COUNT.load(Ordering::SeqCst), 1);
484 
485     client.get("/not-found").send().await;
486     assert_eq!(COUNT.load(Ordering::SeqCst), 2);
487 }
488 
489 #[crate::test]
490 #[should_panic(expected = "\
491     Invalid route: `Router::route_service` cannot be used with `Router`s. \
492     Use `Router::nest` instead\
493 ")]
routing_to_router_panics()494 async fn routing_to_router_panics() {
495     TestClient::new(Router::new().route_service("/", Router::new()));
496 }
497 
498 #[crate::test]
route_layer()499 async fn route_layer() {
500     let app = Router::new()
501         .route("/foo", get(|| async {}))
502         .route_layer(ValidateRequestHeaderLayer::bearer("password"));
503 
504     let client = TestClient::new(app);
505 
506     let res = client
507         .get("/foo")
508         .header("authorization", "Bearer password")
509         .send()
510         .await;
511     assert_eq!(res.status(), StatusCode::OK);
512 
513     let res = client.get("/foo").send().await;
514     assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
515 
516     let res = client.get("/not-found").send().await;
517     assert_eq!(res.status(), StatusCode::NOT_FOUND);
518 
519     // it would be nice if this would return `405 Method Not Allowed`
520     // but that requires knowing more about which method route we're calling, which we
521     // don't know currently since its just a generic `Service`
522     let res = client.post("/foo").send().await;
523     assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
524 }
525 
526 #[crate::test]
different_methods_added_in_different_routes()527 async fn different_methods_added_in_different_routes() {
528     let app = Router::new()
529         .route("/", get(|| async { "GET" }))
530         .route("/", post(|| async { "POST" }));
531 
532     let client = TestClient::new(app);
533 
534     let res = client.get("/").send().await;
535     let body = res.text().await;
536     assert_eq!(body, "GET");
537 
538     let res = client.post("/").send().await;
539     let body = res.text().await;
540     assert_eq!(body, "POST");
541 }
542 
543 #[crate::test]
544 #[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")]
merging_routers_with_fallbacks_panics()545 async fn merging_routers_with_fallbacks_panics() {
546     async fn fallback() {}
547     let one = Router::new().fallback(fallback);
548     let two = Router::new().fallback(fallback);
549     TestClient::new(one.merge(two));
550 }
551 
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn routes_with_overlapping_method_routes() {
    async fn noop() {}

    // Registering a GET handler twice for the same path panics on the second call.
    let mut router: Router = Router::new();
    for _ in 0..2 {
        router = router.route("/foo/bar", get(noop));
    }
}
560 
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn merging_with_overlapping_method_routes() {
    async fn noop() {}

    // Merging a router with a copy of itself duplicates its GET route.
    let router: Router = Router::new().route("/foo/bar", get(noop));
    let copy = router.clone();
    let _ = copy.merge(router);
}
568 
569 #[crate::test]
merging_routers_with_same_paths_but_different_methods()570 async fn merging_routers_with_same_paths_but_different_methods() {
571     let one = Router::new().route("/", get(|| async { "GET" }));
572     let two = Router::new().route("/", post(|| async { "POST" }));
573 
574     let client = TestClient::new(one.merge(two));
575 
576     let res = client.get("/").send().await;
577     let body = res.text().await;
578     assert_eq!(body, "GET");
579 
580     let res = client.post("/").send().await;
581     let body = res.text().await;
582     assert_eq!(body, "POST");
583 }
584 
585 #[crate::test]
head_content_length_through_hyper_server()586 async fn head_content_length_through_hyper_server() {
587     let app = Router::new()
588         .route("/", get(|| async { "foo" }))
589         .route("/json", get(|| async { Json(json!({ "foo": 1 })) }));
590 
591     let client = TestClient::new(app);
592 
593     let res = client.head("/").send().await;
594     assert_eq!(res.headers()["content-length"], "3");
595     assert!(res.text().await.is_empty());
596 
597     let res = client.head("/json").send().await;
598     assert_eq!(res.headers()["content-length"], "9");
599     assert!(res.text().await.is_empty());
600 }
601 
602 #[crate::test]
head_content_length_through_hyper_server_that_hits_fallback()603 async fn head_content_length_through_hyper_server_that_hits_fallback() {
604     let app = Router::new().fallback(|| async { "foo" });
605 
606     let client = TestClient::new(app);
607 
608     let res = client.head("/").send().await;
609     assert_eq!(res.headers()["content-length"], "3");
610 }
611 
612 #[crate::test]
head_with_middleware_applied()613 async fn head_with_middleware_applied() {
614     use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
615 
616     let app = Router::new()
617         .nest(
618             "/",
619             Router::new().route("/", get(|| async { "Hello, World!" })),
620         )
621         .layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
622 
623     let client = TestClient::new(app);
624 
625     // send GET request
626     let res = client
627         .get("/")
628         .header("accept-encoding", "gzip")
629         .send()
630         .await;
631     assert_eq!(res.headers()["transfer-encoding"], "chunked");
632     // cannot have `transfer-encoding: chunked` and `content-length`
633     assert!(!res.headers().contains_key("content-length"));
634 
635     // send HEAD request
636     let res = client
637         .head("/")
638         .header("accept-encoding", "gzip")
639         .send()
640         .await;
641     // no response body so no `transfer-encoding`
642     assert!(!res.headers().contains_key("transfer-encoding"));
643     // no content-length since we cannot know it since the response
644     // is compressed
645     assert!(!res.headers().contains_key("content-length"));
646 }
647 
648 #[crate::test]
649 #[should_panic(expected = "Paths must start with a `/`")]
routes_must_start_with_slash()650 async fn routes_must_start_with_slash() {
651     let app = Router::new().route(":foo", get(|| async {}));
652     TestClient::new(app);
653 }
654 
655 #[crate::test]
body_limited_by_default()656 async fn body_limited_by_default() {
657     let app = Router::new()
658         .route("/bytes", post(|_: Bytes| async {}))
659         .route("/string", post(|_: String| async {}))
660         .route("/json", post(|_: Json<serde_json::Value>| async {}));
661 
662     let client = TestClient::new(app);
663 
664     for uri in ["/bytes", "/string", "/json"] {
665         println!("calling {uri}");
666 
667         let stream = futures_util::stream::repeat("a".repeat(1000)).map(Ok::<_, hyper::Error>);
668         let body = Body::wrap_stream(stream);
669 
670         let res_future = client
671             .post(uri)
672             .header("content-type", "application/json")
673             .body(body)
674             .send();
675         let res = tokio::time::timeout(Duration::from_secs(3), res_future)
676             .await
677             .expect("never got response");
678 
679         assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
680     }
681 }
682 
683 #[crate::test]
disabling_the_default_limit()684 async fn disabling_the_default_limit() {
685     let app = Router::new()
686         .route("/", post(|_: Bytes| async {}))
687         .layer(DefaultBodyLimit::disable());
688 
689     let client = TestClient::new(app);
690 
691     // `DEFAULT_LIMIT` is 2mb so make a body larger than that
692     let body = Body::from("a".repeat(3_000_000));
693 
694     let res = client.post("/").body(body).send().await;
695 
696     assert_eq!(res.status(), StatusCode::OK);
697 }
698 
699 #[crate::test]
limited_body_with_content_length()700 async fn limited_body_with_content_length() {
701     const LIMIT: usize = 3;
702 
703     let app = Router::new()
704         .route(
705             "/",
706             post(|headers: HeaderMap, _body: Bytes| async move {
707                 assert!(headers.get(CONTENT_LENGTH).is_some());
708             }),
709         )
710         .layer(RequestBodyLimitLayer::new(LIMIT));
711 
712     let client = TestClient::new(app);
713 
714     let res = client.post("/").body("a".repeat(LIMIT)).send().await;
715     assert_eq!(res.status(), StatusCode::OK);
716 
717     let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await;
718     assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
719 }
720 
721 #[crate::test]
changing_the_default_limit()722 async fn changing_the_default_limit() {
723     let new_limit = 2;
724 
725     let app = Router::new()
726         .route("/", post(|_: Bytes| async {}))
727         .layer(DefaultBodyLimit::max(new_limit));
728 
729     let client = TestClient::new(app);
730 
731     let res = client
732         .post("/")
733         .body(Body::from("a".repeat(new_limit)))
734         .send()
735         .await;
736     assert_eq!(res.status(), StatusCode::OK);
737 
738     let res = client
739         .post("/")
740         .body(Body::from("a".repeat(new_limit + 1)))
741         .send()
742         .await;
743     assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
744 }
745 
746 #[crate::test]
limited_body_with_streaming_body()747 async fn limited_body_with_streaming_body() {
748     const LIMIT: usize = 3;
749 
750     let app = Router::new()
751         .route(
752             "/",
753             post(|headers: HeaderMap, _body: Bytes| async move {
754                 assert!(headers.get(CONTENT_LENGTH).is_none());
755             }),
756         )
757         .layer(RequestBodyLimitLayer::new(LIMIT));
758 
759     let client = TestClient::new(app);
760 
761     let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT))]);
762     let res = client
763         .post("/")
764         .body(Body::wrap_stream(stream))
765         .send()
766         .await;
767     assert_eq!(res.status(), StatusCode::OK);
768 
769     let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT * 2))]);
770     let res = client
771         .post("/")
772         .body(Body::wrap_stream(stream))
773         .send()
774         .await;
775     assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
776 }
777 
778 #[crate::test]
extract_state()779 async fn extract_state() {
780     #[derive(Clone)]
781     struct AppState {
782         value: i32,
783         inner: InnerState,
784     }
785 
786     #[derive(Clone)]
787     struct InnerState {
788         value: i32,
789     }
790 
791     impl FromRef<AppState> for InnerState {
792         fn from_ref(state: &AppState) -> Self {
793             state.inner.clone()
794         }
795     }
796 
797     async fn handler(State(outer): State<AppState>, State(inner): State<InnerState>) {
798         assert_eq!(outer.value, 1);
799         assert_eq!(inner.value, 2);
800     }
801 
802     let state = AppState {
803         value: 1,
804         inner: InnerState { value: 2 },
805     };
806 
807     let app = Router::new().route("/", get(handler)).with_state(state);
808     let client = TestClient::new(app);
809 
810     let res = client.get("/").send().await;
811     assert_eq!(res.status(), StatusCode::OK);
812 }
813 
814 #[crate::test]
explicitly_set_state()815 async fn explicitly_set_state() {
816     let app = Router::new()
817         .route_service(
818             "/",
819             get(|State(state): State<&'static str>| async move { state }).with_state("foo"),
820         )
821         .with_state("...");
822 
823     let client = TestClient::new(app);
824     let res = client.get("/").send().await;
825     assert_eq!(res.text().await, "foo");
826 }
827 
828 #[crate::test]
layer_response_into_response()829 async fn layer_response_into_response() {
830     fn map_response<B>(_res: Response<B>) -> Result<Response<B>, impl IntoResponse> {
831         let headers = [("x-foo", "bar")];
832         let status = StatusCode::IM_A_TEAPOT;
833         Err((headers, status))
834     }
835 
836     let app = Router::new()
837         .route("/", get(|| async {}))
838         .layer(MapResponseLayer::new(map_response));
839 
840     let client = TestClient::new(app);
841 
842     let res = client.get("/").send().await;
843     assert_eq!(res.headers()["x-foo"], "bar");
844     assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
845 }
846 
847 #[allow(dead_code)]
method_router_fallback_with_state()848 fn method_router_fallback_with_state() {
849     async fn fallback(_: State<&'static str>) {}
850 
851     async fn not_found(_: State<&'static str>) {}
852 
853     let state = "foo";
854 
855     let _: Router = Router::new()
856         .fallback(get(fallback).fallback(not_found))
857         .with_state(state);
858 }
859 
#[test]
fn test_path_for_nested_route() {
    // Each case is (outer prefix, nested path, expected joined path).
    // - The halves are joined with exactly one `/`.
    // - A bare `/` on either side adds nothing.
    // - A trailing slash on the nested path is kept.
    let cases = [
        ("/", "/", "/"),
        ("/a", "/", "/a"),
        ("/", "/b", "/b"),
        ("/a/", "/", "/a/"),
        ("/", "/b/", "/b/"),
        ("/a", "/b", "/a/b"),
        ("/a/", "/b", "/a/b"),
        ("/a", "/b/", "/a/b/"),
        ("/a/", "/b/", "/a/b/"),
    ];

    for (prefix, nested, expected) in cases {
        assert_eq!(path_for_nested_route(prefix, nested), expected);
    }
}
874 
// Guards against regressions that clone the router state more often than
// needed while serving a request. `AppState`'s `Clone` impl counts every clone
// made after setup, and the test pins that count for one GET request.
#[crate::test]
async fn state_isnt_cloned_too_much() {
    // Set to `true` once the router and client are built, so clones made
    // during setup are not counted.
    static SETUP_DONE: AtomicBool = AtomicBool::new(false);
    // Number of `AppState` clones observed after setup.
    static COUNT: AtomicUsize = AtomicUsize::new(0);

    struct AppState;

    impl Clone for AppState {
        fn clone(&self) -> Self {
            // Rust >= 1.65 (`std::backtrace` available): besides counting,
            // print a backtrace filtered to lines containing "axum" or
            // "./src", to help locate the clone site when the count changes.
            #[rustversion::since(1.65)]
            #[track_caller]
            fn count() {
                if SETUP_DONE.load(Ordering::SeqCst) {
                    let bt = std::backtrace::Backtrace::force_capture();
                    let bt = bt
                        .to_string()
                        .lines()
                        .filter(|line| line.contains("axum") || line.contains("./src"))
                        .collect::<Vec<_>>()
                        .join("\n");
                    println!("AppState::Clone:\n===============\n{}\n", bt);
                    COUNT.fetch_add(1, Ordering::SeqCst);
                }
            }

            // Older compilers: count only, no backtrace.
            #[rustversion::not(since(1.65))]
            fn count() {
                if SETUP_DONE.load(Ordering::SeqCst) {
                    COUNT.fetch_add(1, Ordering::SeqCst);
                }
            }

            count();

            Self
        }
    }

    let app = Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(AppState);

    let client = TestClient::new(app);

    // ignore clones made during setup
    SETUP_DONE.store(true, Ordering::SeqCst);

    client.get("/").send().await;

    // NOTE(review): the expected number reflects the current routing/extraction
    // internals; if it changes, use the printed backtraces to confirm the new
    // count is intentional before updating it.
    assert_eq!(COUNT.load(Ordering::SeqCst), 4);
}
926 
927 #[crate::test]
logging_rejections()928 async fn logging_rejections() {
929     #[derive(Deserialize, Eq, PartialEq, Debug)]
930     #[serde(deny_unknown_fields)]
931     struct RejectionEvent {
932         message: String,
933         status: u16,
934         body: String,
935         rejection_type: String,
936     }
937 
938     let events = capture_tracing::<RejectionEvent, _, _>(|| async {
939         let app = Router::new()
940             .route("/extension", get(|_: Extension<Infallible>| async {}))
941             .route("/string", post(|_: String| async {}));
942 
943         let client = TestClient::new(app);
944 
945         assert_eq!(
946             client.get("/extension").send().await.status(),
947             StatusCode::INTERNAL_SERVER_ERROR
948         );
949 
950         assert_eq!(
951             client
952                 .post("/string")
953                 .body(Vec::from([0, 159, 146, 150]))
954                 .send()
955                 .await
956                 .status(),
957             StatusCode::BAD_REQUEST,
958         );
959     })
960     .await;
961 
962     assert_eq!(
963         dbg!(events),
964         Vec::from([
965             TracingEvent {
966                 fields: RejectionEvent {
967                     message: "rejecting request".to_owned(),
968                     status: 500,
969                     body: "Missing request extension: Extension of \
970                         type `core::convert::Infallible` was not found. \
971                         Perhaps you forgot to add it? See `axum::Extension`."
972                         .to_owned(),
973                     rejection_type: "axum::extract::rejection::MissingExtension".to_owned(),
974                 },
975                 target: "axum::rejection".to_owned(),
976                 level: "TRACE".to_owned(),
977             },
978             TracingEvent {
979                 fields: RejectionEvent {
980                     message: "rejecting request".to_owned(),
981                     status: 400,
982                     body: "Request body didn't contain valid UTF-8: \
983                         invalid utf-8 sequence of 1 bytes from index 1"
984                         .to_owned(),
985                     rejection_type: "axum_core::extract::rejection::InvalidUtf8".to_owned(),
986                 },
987                 target: "axum::rejection".to_owned(),
988                 level: "TRACE".to_owned(),
989             },
990         ])
991     )
992 }
993 
994 // https://github.com/tokio-rs/axum/issues/1955
995 #[crate::test]
connect_going_to_custom_fallback()996 async fn connect_going_to_custom_fallback() {
997     let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") });
998 
999     let req = Request::builder()
1000         .uri("example.com:443")
1001         .method(Method::CONNECT)
1002         .header(HOST, "example.com:443")
1003         .body(Body::empty())
1004         .unwrap();
1005 
1006     let res = app.oneshot(req).await.unwrap();
1007     assert_eq!(res.status(), StatusCode::NOT_FOUND);
1008     let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap();
1009     assert_eq!(text, "custom fallback");
1010 }
1011 
1012 // https://github.com/tokio-rs/axum/issues/1955
1013 #[crate::test]
connect_going_to_default_fallback()1014 async fn connect_going_to_default_fallback() {
1015     let app = Router::new();
1016 
1017     let req = Request::builder()
1018         .uri("example.com:443")
1019         .method(Method::CONNECT)
1020         .header(HOST, "example.com:443")
1021         .body(Body::empty())
1022         .unwrap();
1023 
1024     let res = app.oneshot(req).await.unwrap();
1025     assert_eq!(res.status(), StatusCode::NOT_FOUND);
1026     let body = hyper::body::to_bytes(res).await.unwrap();
1027     assert!(body.is_empty());
1028 }
1029 
1030 #[crate::test]
impl_handler_for_into_response()1031 async fn impl_handler_for_into_response() {
1032     let app = Router::new().route("/things", post((StatusCode::CREATED, "thing created")));
1033 
1034     let client = TestClient::new(app);
1035 
1036     let res = client.post("/things").send().await;
1037     assert_eq!(res.status(), StatusCode::CREATED);
1038     assert_eq!(res.text().await, "thing created");
1039 }
1040