1 use crate::{
2 body::{Bytes, Empty},
3 error_handling::HandleErrorLayer,
4 extract::{self, DefaultBodyLimit, FromRef, Path, State},
5 handler::{Handler, HandlerWithoutStateExt},
6 response::IntoResponse,
7 routing::{
8 delete, get, get_service, on, on_service, patch, patch_service,
9 path_router::path_for_nested_route, post, MethodFilter,
10 },
11 test_helpers::{
12 tracing_helpers::{capture_tracing, TracingEvent},
13 *,
14 },
15 BoxError, Extension, Json, Router,
16 };
17 use futures_util::stream::StreamExt;
18 use http::{
19 header::CONTENT_LENGTH,
20 header::{ALLOW, HOST},
21 HeaderMap, Method, Request, Response, StatusCode, Uri,
22 };
23 use hyper::Body;
24 use serde::Deserialize;
25 use serde_json::json;
26 use std::{
27 convert::Infallible,
28 future::{ready, Ready},
29 sync::atomic::{AtomicBool, AtomicUsize, Ordering},
30 task::{Context, Poll},
31 time::Duration,
32 };
33 use tower::{
34 service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
35 };
36 use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
37 use tower_service::Service;
38
39 mod fallback;
40 mod get_to_head;
41 mod handle_error;
42 mod merge;
43 mod nest;
44
45 #[crate::test]
hello_world()46 async fn hello_world() {
47 async fn root(_: Request<Body>) -> &'static str {
48 "Hello, World!"
49 }
50
51 async fn foo(_: Request<Body>) -> &'static str {
52 "foo"
53 }
54
55 async fn users_create(_: Request<Body>) -> &'static str {
56 "users#create"
57 }
58
59 let app = Router::new()
60 .route("/", get(root).post(foo))
61 .route("/users", post(users_create));
62
63 let client = TestClient::new(app);
64
65 let res = client.get("/").send().await;
66 let body = res.text().await;
67 assert_eq!(body, "Hello, World!");
68
69 let res = client.post("/").send().await;
70 let body = res.text().await;
71 assert_eq!(body, "foo");
72
73 let res = client.post("/users").send().await;
74 let body = res.text().await;
75 assert_eq!(body, "users#create");
76 }
77
78 #[crate::test]
routing()79 async fn routing() {
80 let app = Router::new()
81 .route(
82 "/users",
83 get(|_: Request<Body>| async { "users#index" })
84 .post(|_: Request<Body>| async { "users#create" }),
85 )
86 .route("/users/:id", get(|_: Request<Body>| async { "users#show" }))
87 .route(
88 "/users/:id/action",
89 get(|_: Request<Body>| async { "users#action" }),
90 );
91
92 let client = TestClient::new(app);
93
94 let res = client.get("/").send().await;
95 assert_eq!(res.status(), StatusCode::NOT_FOUND);
96
97 let res = client.get("/users").send().await;
98 assert_eq!(res.status(), StatusCode::OK);
99 assert_eq!(res.text().await, "users#index");
100
101 let res = client.post("/users").send().await;
102 assert_eq!(res.status(), StatusCode::OK);
103 assert_eq!(res.text().await, "users#create");
104
105 let res = client.get("/users/1").send().await;
106 assert_eq!(res.status(), StatusCode::OK);
107 assert_eq!(res.text().await, "users#show");
108
109 let res = client.get("/users/1/action").send().await;
110 assert_eq!(res.status(), StatusCode::OK);
111 assert_eq!(res.text().await, "users#action");
112 }
113
114 #[crate::test]
router_type_doesnt_change()115 async fn router_type_doesnt_change() {
116 let app: Router = Router::new()
117 .route(
118 "/",
119 on(MethodFilter::GET, |_: Request<Body>| async {
120 "hi from GET"
121 })
122 .on(MethodFilter::POST, |_: Request<Body>| async {
123 "hi from POST"
124 }),
125 )
126 .layer(tower_http::compression::CompressionLayer::new());
127
128 let client = TestClient::new(app);
129
130 let res = client.get("/").send().await;
131 assert_eq!(res.status(), StatusCode::OK);
132 assert_eq!(res.text().await, "hi from GET");
133
134 let res = client.post("/").send().await;
135 assert_eq!(res.status(), StatusCode::OK);
136 assert_eq!(res.text().await, "hi from POST");
137 }
138
139 #[crate::test]
routing_between_services()140 async fn routing_between_services() {
141 use std::convert::Infallible;
142 use tower::service_fn;
143
144 async fn handle(_: Request<Body>) -> &'static str {
145 "handler"
146 }
147
148 let app = Router::new()
149 .route(
150 "/one",
151 get_service(service_fn(|_: Request<Body>| async {
152 Ok::<_, Infallible>(Response::new(Body::from("one get")))
153 }))
154 .post_service(service_fn(|_: Request<Body>| async {
155 Ok::<_, Infallible>(Response::new(Body::from("one post")))
156 }))
157 .on_service(
158 MethodFilter::PUT,
159 service_fn(|_: Request<Body>| async {
160 Ok::<_, Infallible>(Response::new(Body::from("one put")))
161 }),
162 ),
163 )
164 .route("/two", on_service(MethodFilter::GET, handle.into_service()));
165
166 let client = TestClient::new(app);
167
168 let res = client.get("/one").send().await;
169 assert_eq!(res.status(), StatusCode::OK);
170 assert_eq!(res.text().await, "one get");
171
172 let res = client.post("/one").send().await;
173 assert_eq!(res.status(), StatusCode::OK);
174 assert_eq!(res.text().await, "one post");
175
176 let res = client.put("/one").send().await;
177 assert_eq!(res.status(), StatusCode::OK);
178 assert_eq!(res.text().await, "one put");
179
180 let res = client.get("/two").send().await;
181 assert_eq!(res.status(), StatusCode::OK);
182 assert_eq!(res.text().await, "handler");
183 }
184
185 #[crate::test]
middleware_on_single_route()186 async fn middleware_on_single_route() {
187 use tower::ServiceBuilder;
188 use tower_http::{compression::CompressionLayer, trace::TraceLayer};
189
190 async fn handle(_: Request<Body>) -> &'static str {
191 "Hello, World!"
192 }
193
194 let app = Router::new().route(
195 "/",
196 get(handle.layer(
197 ServiceBuilder::new()
198 .layer(TraceLayer::new_for_http())
199 .layer(CompressionLayer::new())
200 .into_inner(),
201 )),
202 );
203
204 let client = TestClient::new(app);
205
206 let res = client.get("/").send().await;
207 let body = res.text().await;
208
209 assert_eq!(body, "Hello, World!");
210 }
211
212 #[crate::test]
service_in_bottom()213 async fn service_in_bottom() {
214 async fn handler(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
215 Ok(Response::new(hyper::Body::empty()))
216 }
217
218 let app = Router::new().route("/", get_service(service_fn(handler)));
219
220 TestClient::new(app);
221 }
222
223 #[crate::test]
wrong_method_handler()224 async fn wrong_method_handler() {
225 let app = Router::new()
226 .route("/", get(|| async {}).post(|| async {}))
227 .route("/foo", patch(|| async {}));
228
229 let client = TestClient::new(app);
230
231 let res = client.patch("/").send().await;
232 assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
233 assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST");
234
235 let res = client.patch("/foo").send().await;
236 assert_eq!(res.status(), StatusCode::OK);
237
238 let res = client.post("/foo").send().await;
239 assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
240 assert_eq!(res.headers()[ALLOW], "PATCH");
241
242 let res = client.get("/bar").send().await;
243 assert_eq!(res.status(), StatusCode::NOT_FOUND);
244 }
245
246 #[crate::test]
wrong_method_service()247 async fn wrong_method_service() {
248 #[derive(Clone)]
249 struct Svc;
250
251 impl<R> Service<R> for Svc {
252 type Response = Response<Empty<Bytes>>;
253 type Error = Infallible;
254 type Future = Ready<Result<Self::Response, Self::Error>>;
255
256 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
257 Poll::Ready(Ok(()))
258 }
259
260 fn call(&mut self, _req: R) -> Self::Future {
261 ready(Ok(Response::new(Empty::new())))
262 }
263 }
264
265 let app = Router::new()
266 .route("/", get_service(Svc).post_service(Svc))
267 .route("/foo", patch_service(Svc));
268
269 let client = TestClient::new(app);
270
271 let res = client.patch("/").send().await;
272 assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
273 assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST");
274
275 let res = client.patch("/foo").send().await;
276 assert_eq!(res.status(), StatusCode::OK);
277
278 let res = client.post("/foo").send().await;
279 assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
280 assert_eq!(res.headers()[ALLOW], "PATCH");
281
282 let res = client.get("/bar").send().await;
283 assert_eq!(res.status(), StatusCode::NOT_FOUND);
284 }
285
286 #[crate::test]
multiple_methods_for_one_handler()287 async fn multiple_methods_for_one_handler() {
288 async fn root(_: Request<Body>) -> &'static str {
289 "Hello, World!"
290 }
291
292 let app = Router::new().route("/", on(MethodFilter::GET | MethodFilter::POST, root));
293
294 let client = TestClient::new(app);
295
296 let res = client.get("/").send().await;
297 assert_eq!(res.status(), StatusCode::OK);
298
299 let res = client.post("/").send().await;
300 assert_eq!(res.status(), StatusCode::OK);
301 }
302
303 #[crate::test]
wildcard_sees_whole_url()304 async fn wildcard_sees_whole_url() {
305 let app = Router::new().route("/api/*rest", get(|uri: Uri| async move { uri.to_string() }));
306
307 let client = TestClient::new(app);
308
309 let res = client.get("/api/foo/bar").send().await;
310 assert_eq!(res.text().await, "/api/foo/bar");
311 }
312
313 #[crate::test]
middleware_applies_to_routes_above()314 async fn middleware_applies_to_routes_above() {
315 let app = Router::new()
316 .route("/one", get(std::future::pending::<()>))
317 .layer(
318 ServiceBuilder::new()
319 .layer(HandleErrorLayer::new(|_: BoxError| async move {
320 StatusCode::REQUEST_TIMEOUT
321 }))
322 .layer(TimeoutLayer::new(Duration::new(0, 0))),
323 )
324 .route("/two", get(|| async {}));
325
326 let client = TestClient::new(app);
327
328 let res = client.get("/one").send().await;
329 assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
330
331 let res = client.get("/two").send().await;
332 assert_eq!(res.status(), StatusCode::OK);
333 }
334
335 #[crate::test]
not_found_for_extra_trailing_slash()336 async fn not_found_for_extra_trailing_slash() {
337 let app = Router::new().route("/foo", get(|| async {}));
338
339 let client = TestClient::new(app);
340
341 let res = client.get("/foo/").send().await;
342 assert_eq!(res.status(), StatusCode::NOT_FOUND);
343
344 let res = client.get("/foo").send().await;
345 assert_eq!(res.status(), StatusCode::OK);
346 }
347
348 #[crate::test]
not_found_for_missing_trailing_slash()349 async fn not_found_for_missing_trailing_slash() {
350 let app = Router::new().route("/foo/", get(|| async {}));
351
352 let client = TestClient::new(app);
353
354 let res = client.get("/foo").send().await;
355 assert_eq!(res.status(), StatusCode::NOT_FOUND);
356 }
357
358 #[crate::test]
with_and_without_trailing_slash()359 async fn with_and_without_trailing_slash() {
360 let app = Router::new()
361 .route("/foo", get(|| async { "without tsr" }))
362 .route("/foo/", get(|| async { "with tsr" }));
363
364 let client = TestClient::new(app);
365
366 let res = client.get("/foo/").send().await;
367 assert_eq!(res.status(), StatusCode::OK);
368 assert_eq!(res.text().await, "with tsr");
369
370 let res = client.get("/foo").send().await;
371 assert_eq!(res.status(), StatusCode::OK);
372 assert_eq!(res.text().await, "without tsr");
373 }
374
375 // for https://github.com/tokio-rs/axum/issues/420
376 #[crate::test]
wildcard_doesnt_match_just_trailing_slash()377 async fn wildcard_doesnt_match_just_trailing_slash() {
378 let app = Router::new().route(
379 "/x/*path",
380 get(|Path(path): Path<String>| async move { path }),
381 );
382
383 let client = TestClient::new(app);
384
385 let res = client.get("/x").send().await;
386 assert_eq!(res.status(), StatusCode::NOT_FOUND);
387
388 let res = client.get("/x/").send().await;
389 assert_eq!(res.status(), StatusCode::NOT_FOUND);
390
391 let res = client.get("/x/foo/bar").send().await;
392 assert_eq!(res.status(), StatusCode::OK);
393 assert_eq!(res.text().await, "foo/bar");
394 }
395
396 #[crate::test]
what_matches_wildcard()397 async fn what_matches_wildcard() {
398 let app = Router::new()
399 .route("/*key", get(|| async { "root" }))
400 .route("/x/*key", get(|| async { "x" }))
401 .fallback(|| async { "fallback" });
402
403 let client = TestClient::new(app);
404
405 let get = |path| {
406 let f = client.get(path).send();
407 async move { f.await.text().await }
408 };
409
410 assert_eq!(get("/").await, "fallback");
411 assert_eq!(get("/a").await, "root");
412 assert_eq!(get("/a/").await, "root");
413 assert_eq!(get("/a/b").await, "root");
414 assert_eq!(get("/a/b/").await, "root");
415
416 assert_eq!(get("/x").await, "root");
417 assert_eq!(get("/x/").await, "root");
418 assert_eq!(get("/x/a").await, "x");
419 assert_eq!(get("/x/a/").await, "x");
420 assert_eq!(get("/x/a/b").await, "x");
421 assert_eq!(get("/x/a/b/").await, "x");
422 }
423
424 #[crate::test]
static_and_dynamic_paths()425 async fn static_and_dynamic_paths() {
426 let app = Router::new()
427 .route(
428 "/:key",
429 get(|Path(key): Path<String>| async move { format!("dynamic: {key}") }),
430 )
431 .route("/foo", get(|| async { "static" }));
432
433 let client = TestClient::new(app);
434
435 let res = client.get("/bar").send().await;
436 assert_eq!(res.text().await, "dynamic: bar");
437
438 let res = client.get("/foo").send().await;
439 assert_eq!(res.text().await, "static");
440 }
441
442 #[crate::test]
443 #[should_panic(expected = "Paths must start with a `/`. Use \"/\" for root routes")]
empty_route()444 async fn empty_route() {
445 let app = Router::new().route("", get(|| async {}));
446 TestClient::new(app);
447 }
448
449 #[crate::test]
middleware_still_run_for_unmatched_requests()450 async fn middleware_still_run_for_unmatched_requests() {
451 #[derive(Clone)]
452 struct CountMiddleware<S>(S);
453
454 static COUNT: AtomicUsize = AtomicUsize::new(0);
455
456 impl<R, S> Service<R> for CountMiddleware<S>
457 where
458 S: Service<R>,
459 {
460 type Response = S::Response;
461 type Error = S::Error;
462 type Future = S::Future;
463
464 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
465 self.0.poll_ready(cx)
466 }
467
468 fn call(&mut self, req: R) -> Self::Future {
469 COUNT.fetch_add(1, Ordering::SeqCst);
470 self.0.call(req)
471 }
472 }
473
474 let app = Router::new()
475 .route("/", get(|| async {}))
476 .layer(tower::layer::layer_fn(CountMiddleware));
477
478 let client = TestClient::new(app);
479
480 assert_eq!(COUNT.load(Ordering::SeqCst), 0);
481
482 client.get("/").send().await;
483 assert_eq!(COUNT.load(Ordering::SeqCst), 1);
484
485 client.get("/not-found").send().await;
486 assert_eq!(COUNT.load(Ordering::SeqCst), 2);
487 }
488
489 #[crate::test]
490 #[should_panic(expected = "\
491 Invalid route: `Router::route_service` cannot be used with `Router`s. \
492 Use `Router::nest` instead\
493 ")]
routing_to_router_panics()494 async fn routing_to_router_panics() {
495 TestClient::new(Router::new().route_service("/", Router::new()));
496 }
497
498 #[crate::test]
route_layer()499 async fn route_layer() {
500 let app = Router::new()
501 .route("/foo", get(|| async {}))
502 .route_layer(ValidateRequestHeaderLayer::bearer("password"));
503
504 let client = TestClient::new(app);
505
506 let res = client
507 .get("/foo")
508 .header("authorization", "Bearer password")
509 .send()
510 .await;
511 assert_eq!(res.status(), StatusCode::OK);
512
513 let res = client.get("/foo").send().await;
514 assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
515
516 let res = client.get("/not-found").send().await;
517 assert_eq!(res.status(), StatusCode::NOT_FOUND);
518
519 // it would be nice if this would return `405 Method Not Allowed`
520 // but that requires knowing more about which method route we're calling, which we
521 // don't know currently since its just a generic `Service`
522 let res = client.post("/foo").send().await;
523 assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
524 }
525
526 #[crate::test]
different_methods_added_in_different_routes()527 async fn different_methods_added_in_different_routes() {
528 let app = Router::new()
529 .route("/", get(|| async { "GET" }))
530 .route("/", post(|| async { "POST" }));
531
532 let client = TestClient::new(app);
533
534 let res = client.get("/").send().await;
535 let body = res.text().await;
536 assert_eq!(body, "GET");
537
538 let res = client.post("/").send().await;
539 let body = res.text().await;
540 assert_eq!(body, "POST");
541 }
542
543 #[crate::test]
544 #[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")]
merging_routers_with_fallbacks_panics()545 async fn merging_routers_with_fallbacks_panics() {
546 async fn fallback() {}
547 let one = Router::new().fallback(fallback);
548 let two = Router::new().fallback(fallback);
549 TestClient::new(one.merge(two));
550 }
551
/// Registering the same method on the same path twice panics at
/// router-construction time.
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn routes_with_overlapping_method_routes() {
    async fn handler() {}

    let first = get(handler);
    let second = get(handler);
    let _: Router = Router::new()
        .route("/foo/bar", first)
        .route("/foo/bar", second);
}
560
/// Merging a router with a clone of itself panics, because both sides define
/// `GET /foo/bar`.
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn merging_with_overlapping_method_routes() {
    async fn handler() {}

    let app: Router = Router::new().route("/foo/bar", get(handler));
    let duplicate = app.clone();
    let _ = duplicate.merge(app);
}
568
569 #[crate::test]
merging_routers_with_same_paths_but_different_methods()570 async fn merging_routers_with_same_paths_but_different_methods() {
571 let one = Router::new().route("/", get(|| async { "GET" }));
572 let two = Router::new().route("/", post(|| async { "POST" }));
573
574 let client = TestClient::new(one.merge(two));
575
576 let res = client.get("/").send().await;
577 let body = res.text().await;
578 assert_eq!(body, "GET");
579
580 let res = client.post("/").send().await;
581 let body = res.text().await;
582 assert_eq!(body, "POST");
583 }
584
585 #[crate::test]
head_content_length_through_hyper_server()586 async fn head_content_length_through_hyper_server() {
587 let app = Router::new()
588 .route("/", get(|| async { "foo" }))
589 .route("/json", get(|| async { Json(json!({ "foo": 1 })) }));
590
591 let client = TestClient::new(app);
592
593 let res = client.head("/").send().await;
594 assert_eq!(res.headers()["content-length"], "3");
595 assert!(res.text().await.is_empty());
596
597 let res = client.head("/json").send().await;
598 assert_eq!(res.headers()["content-length"], "9");
599 assert!(res.text().await.is_empty());
600 }
601
602 #[crate::test]
head_content_length_through_hyper_server_that_hits_fallback()603 async fn head_content_length_through_hyper_server_that_hits_fallback() {
604 let app = Router::new().fallback(|| async { "foo" });
605
606 let client = TestClient::new(app);
607
608 let res = client.head("/").send().await;
609 assert_eq!(res.headers()["content-length"], "3");
610 }
611
612 #[crate::test]
head_with_middleware_applied()613 async fn head_with_middleware_applied() {
614 use tower_http::compression::{predicate::SizeAbove, CompressionLayer};
615
616 let app = Router::new()
617 .nest(
618 "/",
619 Router::new().route("/", get(|| async { "Hello, World!" })),
620 )
621 .layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
622
623 let client = TestClient::new(app);
624
625 // send GET request
626 let res = client
627 .get("/")
628 .header("accept-encoding", "gzip")
629 .send()
630 .await;
631 assert_eq!(res.headers()["transfer-encoding"], "chunked");
632 // cannot have `transfer-encoding: chunked` and `content-length`
633 assert!(!res.headers().contains_key("content-length"));
634
635 // send HEAD request
636 let res = client
637 .head("/")
638 .header("accept-encoding", "gzip")
639 .send()
640 .await;
641 // no response body so no `transfer-encoding`
642 assert!(!res.headers().contains_key("transfer-encoding"));
643 // no content-length since we cannot know it since the response
644 // is compressed
645 assert!(!res.headers().contains_key("content-length"));
646 }
647
648 #[crate::test]
649 #[should_panic(expected = "Paths must start with a `/`")]
routes_must_start_with_slash()650 async fn routes_must_start_with_slash() {
651 let app = Router::new().route(":foo", get(|| async {}));
652 TestClient::new(app);
653 }
654
655 #[crate::test]
body_limited_by_default()656 async fn body_limited_by_default() {
657 let app = Router::new()
658 .route("/bytes", post(|_: Bytes| async {}))
659 .route("/string", post(|_: String| async {}))
660 .route("/json", post(|_: Json<serde_json::Value>| async {}));
661
662 let client = TestClient::new(app);
663
664 for uri in ["/bytes", "/string", "/json"] {
665 println!("calling {uri}");
666
667 let stream = futures_util::stream::repeat("a".repeat(1000)).map(Ok::<_, hyper::Error>);
668 let body = Body::wrap_stream(stream);
669
670 let res_future = client
671 .post(uri)
672 .header("content-type", "application/json")
673 .body(body)
674 .send();
675 let res = tokio::time::timeout(Duration::from_secs(3), res_future)
676 .await
677 .expect("never got response");
678
679 assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
680 }
681 }
682
683 #[crate::test]
disabling_the_default_limit()684 async fn disabling_the_default_limit() {
685 let app = Router::new()
686 .route("/", post(|_: Bytes| async {}))
687 .layer(DefaultBodyLimit::disable());
688
689 let client = TestClient::new(app);
690
691 // `DEFAULT_LIMIT` is 2mb so make a body larger than that
692 let body = Body::from("a".repeat(3_000_000));
693
694 let res = client.post("/").body(body).send().await;
695
696 assert_eq!(res.status(), StatusCode::OK);
697 }
698
699 #[crate::test]
limited_body_with_content_length()700 async fn limited_body_with_content_length() {
701 const LIMIT: usize = 3;
702
703 let app = Router::new()
704 .route(
705 "/",
706 post(|headers: HeaderMap, _body: Bytes| async move {
707 assert!(headers.get(CONTENT_LENGTH).is_some());
708 }),
709 )
710 .layer(RequestBodyLimitLayer::new(LIMIT));
711
712 let client = TestClient::new(app);
713
714 let res = client.post("/").body("a".repeat(LIMIT)).send().await;
715 assert_eq!(res.status(), StatusCode::OK);
716
717 let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await;
718 assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
719 }
720
721 #[crate::test]
changing_the_default_limit()722 async fn changing_the_default_limit() {
723 let new_limit = 2;
724
725 let app = Router::new()
726 .route("/", post(|_: Bytes| async {}))
727 .layer(DefaultBodyLimit::max(new_limit));
728
729 let client = TestClient::new(app);
730
731 let res = client
732 .post("/")
733 .body(Body::from("a".repeat(new_limit)))
734 .send()
735 .await;
736 assert_eq!(res.status(), StatusCode::OK);
737
738 let res = client
739 .post("/")
740 .body(Body::from("a".repeat(new_limit + 1)))
741 .send()
742 .await;
743 assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
744 }
745
746 #[crate::test]
limited_body_with_streaming_body()747 async fn limited_body_with_streaming_body() {
748 const LIMIT: usize = 3;
749
750 let app = Router::new()
751 .route(
752 "/",
753 post(|headers: HeaderMap, _body: Bytes| async move {
754 assert!(headers.get(CONTENT_LENGTH).is_none());
755 }),
756 )
757 .layer(RequestBodyLimitLayer::new(LIMIT));
758
759 let client = TestClient::new(app);
760
761 let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT))]);
762 let res = client
763 .post("/")
764 .body(Body::wrap_stream(stream))
765 .send()
766 .await;
767 assert_eq!(res.status(), StatusCode::OK);
768
769 let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT * 2))]);
770 let res = client
771 .post("/")
772 .body(Body::wrap_stream(stream))
773 .send()
774 .await;
775 assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
776 }
777
778 #[crate::test]
extract_state()779 async fn extract_state() {
780 #[derive(Clone)]
781 struct AppState {
782 value: i32,
783 inner: InnerState,
784 }
785
786 #[derive(Clone)]
787 struct InnerState {
788 value: i32,
789 }
790
791 impl FromRef<AppState> for InnerState {
792 fn from_ref(state: &AppState) -> Self {
793 state.inner.clone()
794 }
795 }
796
797 async fn handler(State(outer): State<AppState>, State(inner): State<InnerState>) {
798 assert_eq!(outer.value, 1);
799 assert_eq!(inner.value, 2);
800 }
801
802 let state = AppState {
803 value: 1,
804 inner: InnerState { value: 2 },
805 };
806
807 let app = Router::new().route("/", get(handler)).with_state(state);
808 let client = TestClient::new(app);
809
810 let res = client.get("/").send().await;
811 assert_eq!(res.status(), StatusCode::OK);
812 }
813
814 #[crate::test]
explicitly_set_state()815 async fn explicitly_set_state() {
816 let app = Router::new()
817 .route_service(
818 "/",
819 get(|State(state): State<&'static str>| async move { state }).with_state("foo"),
820 )
821 .with_state("...");
822
823 let client = TestClient::new(app);
824 let res = client.get("/").send().await;
825 assert_eq!(res.text().await, "foo");
826 }
827
828 #[crate::test]
layer_response_into_response()829 async fn layer_response_into_response() {
830 fn map_response<B>(_res: Response<B>) -> Result<Response<B>, impl IntoResponse> {
831 let headers = [("x-foo", "bar")];
832 let status = StatusCode::IM_A_TEAPOT;
833 Err((headers, status))
834 }
835
836 let app = Router::new()
837 .route("/", get(|| async {}))
838 .layer(MapResponseLayer::new(map_response));
839
840 let client = TestClient::new(app);
841
842 let res = client.get("/").send().await;
843 assert_eq!(res.headers()["x-foo"], "bar");
844 assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
845 }
846
847 #[allow(dead_code)]
method_router_fallback_with_state()848 fn method_router_fallback_with_state() {
849 async fn fallback(_: State<&'static str>) {}
850
851 async fn not_found(_: State<&'static str>) {}
852
853 let state = "foo";
854
855 let _: Router = Router::new()
856 .fallback(get(fallback).fallback(not_found))
857 .with_state(state);
858 }
859
/// Table of expected results for `path_for_nested_route`, covering every
/// combination of trailing slashes on either argument.
#[test]
fn test_path_for_nested_route() {
    // (first argument, second argument, expected result)
    let cases = [
        ("/", "/", "/"),
        ("/a", "/", "/a"),
        ("/", "/b", "/b"),
        ("/a/", "/", "/a/"),
        ("/", "/b/", "/b/"),
        ("/a", "/b", "/a/b"),
        ("/a/", "/b", "/a/b"),
        ("/a", "/b/", "/a/b/"),
        ("/a/", "/b/", "/a/b/"),
    ];

    for (first, second, expected) in cases {
        assert_eq!(
            path_for_nested_route(first, second),
            expected,
            "path_for_nested_route({first:?}, {second:?})"
        );
    }
}
874
875 #[crate::test]
state_isnt_cloned_too_much()876 async fn state_isnt_cloned_too_much() {
877 static SETUP_DONE: AtomicBool = AtomicBool::new(false);
878 static COUNT: AtomicUsize = AtomicUsize::new(0);
879
880 struct AppState;
881
882 impl Clone for AppState {
883 fn clone(&self) -> Self {
884 #[rustversion::since(1.65)]
885 #[track_caller]
886 fn count() {
887 if SETUP_DONE.load(Ordering::SeqCst) {
888 let bt = std::backtrace::Backtrace::force_capture();
889 let bt = bt
890 .to_string()
891 .lines()
892 .filter(|line| line.contains("axum") || line.contains("./src"))
893 .collect::<Vec<_>>()
894 .join("\n");
895 println!("AppState::Clone:\n===============\n{}\n", bt);
896 COUNT.fetch_add(1, Ordering::SeqCst);
897 }
898 }
899
900 #[rustversion::not(since(1.65))]
901 fn count() {
902 if SETUP_DONE.load(Ordering::SeqCst) {
903 COUNT.fetch_add(1, Ordering::SeqCst);
904 }
905 }
906
907 count();
908
909 Self
910 }
911 }
912
913 let app = Router::new()
914 .route("/", get(|_: State<AppState>| async {}))
915 .with_state(AppState);
916
917 let client = TestClient::new(app);
918
919 // ignore clones made during setup
920 SETUP_DONE.store(true, Ordering::SeqCst);
921
922 client.get("/").send().await;
923
924 assert_eq!(COUNT.load(Ordering::SeqCst), 4);
925 }
926
927 #[crate::test]
logging_rejections()928 async fn logging_rejections() {
929 #[derive(Deserialize, Eq, PartialEq, Debug)]
930 #[serde(deny_unknown_fields)]
931 struct RejectionEvent {
932 message: String,
933 status: u16,
934 body: String,
935 rejection_type: String,
936 }
937
938 let events = capture_tracing::<RejectionEvent, _, _>(|| async {
939 let app = Router::new()
940 .route("/extension", get(|_: Extension<Infallible>| async {}))
941 .route("/string", post(|_: String| async {}));
942
943 let client = TestClient::new(app);
944
945 assert_eq!(
946 client.get("/extension").send().await.status(),
947 StatusCode::INTERNAL_SERVER_ERROR
948 );
949
950 assert_eq!(
951 client
952 .post("/string")
953 .body(Vec::from([0, 159, 146, 150]))
954 .send()
955 .await
956 .status(),
957 StatusCode::BAD_REQUEST,
958 );
959 })
960 .await;
961
962 assert_eq!(
963 dbg!(events),
964 Vec::from([
965 TracingEvent {
966 fields: RejectionEvent {
967 message: "rejecting request".to_owned(),
968 status: 500,
969 body: "Missing request extension: Extension of \
970 type `core::convert::Infallible` was not found. \
971 Perhaps you forgot to add it? See `axum::Extension`."
972 .to_owned(),
973 rejection_type: "axum::extract::rejection::MissingExtension".to_owned(),
974 },
975 target: "axum::rejection".to_owned(),
976 level: "TRACE".to_owned(),
977 },
978 TracingEvent {
979 fields: RejectionEvent {
980 message: "rejecting request".to_owned(),
981 status: 400,
982 body: "Request body didn't contain valid UTF-8: \
983 invalid utf-8 sequence of 1 bytes from index 1"
984 .to_owned(),
985 rejection_type: "axum_core::extract::rejection::InvalidUtf8".to_owned(),
986 },
987 target: "axum::rejection".to_owned(),
988 level: "TRACE".to_owned(),
989 },
990 ])
991 )
992 }
993
994 // https://github.com/tokio-rs/axum/issues/1955
995 #[crate::test]
connect_going_to_custom_fallback()996 async fn connect_going_to_custom_fallback() {
997 let app = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") });
998
999 let req = Request::builder()
1000 .uri("example.com:443")
1001 .method(Method::CONNECT)
1002 .header(HOST, "example.com:443")
1003 .body(Body::empty())
1004 .unwrap();
1005
1006 let res = app.oneshot(req).await.unwrap();
1007 assert_eq!(res.status(), StatusCode::NOT_FOUND);
1008 let text = String::from_utf8(hyper::body::to_bytes(res).await.unwrap().to_vec()).unwrap();
1009 assert_eq!(text, "custom fallback");
1010 }
1011
1012 // https://github.com/tokio-rs/axum/issues/1955
1013 #[crate::test]
connect_going_to_default_fallback()1014 async fn connect_going_to_default_fallback() {
1015 let app = Router::new();
1016
1017 let req = Request::builder()
1018 .uri("example.com:443")
1019 .method(Method::CONNECT)
1020 .header(HOST, "example.com:443")
1021 .body(Body::empty())
1022 .unwrap();
1023
1024 let res = app.oneshot(req).await.unwrap();
1025 assert_eq!(res.status(), StatusCode::NOT_FOUND);
1026 let body = hyper::body::to_bytes(res).await.unwrap();
1027 assert!(body.is_empty());
1028 }
1029
1030 #[crate::test]
impl_handler_for_into_response()1031 async fn impl_handler_for_into_response() {
1032 let app = Router::new().route("/things", post((StatusCode::CREATED, "thing created")));
1033
1034 let client = TestClient::new(app);
1035
1036 let res = client.post("/things").send().await;
1037 assert_eq!(res.status(), StatusCode::CREATED);
1038 assert_eq!(res.text().await, "thing created");
1039 }
1040