use crate::{
body::{Bytes, Empty},
error_handling::HandleErrorLayer,
extract::{self, DefaultBodyLimit, FromRef, Path, State},
handler::{Handler, HandlerWithoutStateExt},
response::IntoResponse,
routing::{
delete, get, get_service, on, on_service, patch, patch_service,
path_router::path_for_nested_route, post, MethodFilter,
},
test_helpers::{
tracing_helpers::{capture_tracing, TracingEvent},
*,
},
BoxError, Extension, Json, Router,
};
use futures_util::stream::StreamExt;
use http::{
header::CONTENT_LENGTH,
header::{ALLOW, HOST},
HeaderMap, Method, Request, Response, StatusCode, Uri,
};
use hyper::Body;
use serde::Deserialize;
use serde_json::json;
use std::{
convert::Infallible,
future::{ready, Ready},
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
task::{Context, Poll},
time::Duration,
};
use tower::{
service_fn, timeout::TimeoutLayer, util::MapResponseLayer, ServiceBuilder, ServiceExt,
};
use tower_http::{limit::RequestBodyLimitLayer, validate_request::ValidateRequestHeaderLayer};
use tower_service::Service;
mod fallback;
mod get_to_head;
mod handle_error;
mod merge;
mod nest;
#[crate::test]
async fn hello_world() {
async fn root(_: Request
) -> &'static str {
"Hello, World!"
}
async fn foo(_: Request) -> &'static str {
"foo"
}
async fn users_create(_: Request) -> &'static str {
"users#create"
}
let app = Router::new()
.route("/", get(root).post(foo))
.route("/users", post(users_create));
let client = TestClient::new(app);
let res = client.get("/").send().await;
let body = res.text().await;
assert_eq!(body, "Hello, World!");
let res = client.post("/").send().await;
let body = res.text().await;
assert_eq!(body, "foo");
let res = client.post("/users").send().await;
let body = res.text().await;
assert_eq!(body, "users#create");
}
#[crate::test]
async fn routing() {
let app = Router::new()
.route(
"/users",
get(|_: Request| async { "users#index" })
.post(|_: Request| async { "users#create" }),
)
.route("/users/:id", get(|_: Request| async { "users#show" }))
.route(
"/users/:id/action",
get(|_: Request| async { "users#action" }),
);
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
let res = client.get("/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "users#index");
let res = client.post("/users").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "users#create");
let res = client.get("/users/1").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "users#show");
let res = client.get("/users/1/action").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "users#action");
}
#[crate::test]
async fn router_type_doesnt_change() {
let app: Router = Router::new()
.route(
"/",
on(MethodFilter::GET, |_: Request| async {
"hi from GET"
})
.on(MethodFilter::POST, |_: Request| async {
"hi from POST"
}),
)
.layer(tower_http::compression::CompressionLayer::new());
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "hi from GET");
let res = client.post("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "hi from POST");
}
#[crate::test]
async fn routing_between_services() {
use std::convert::Infallible;
use tower::service_fn;
async fn handle(_: Request) -> &'static str {
"handler"
}
let app = Router::new()
.route(
"/one",
get_service(service_fn(|_: Request| async {
Ok::<_, Infallible>(Response::new(Body::from("one get")))
}))
.post_service(service_fn(|_: Request| async {
Ok::<_, Infallible>(Response::new(Body::from("one post")))
}))
.on_service(
MethodFilter::PUT,
service_fn(|_: Request| async {
Ok::<_, Infallible>(Response::new(Body::from("one put")))
}),
),
)
.route("/two", on_service(MethodFilter::GET, handle.into_service()));
let client = TestClient::new(app);
let res = client.get("/one").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one get");
let res = client.post("/one").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one post");
let res = client.put("/one").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one put");
let res = client.get("/two").send().await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "handler");
}
#[crate::test]
async fn middleware_on_single_route() {
use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
async fn handle(_: Request) -> &'static str {
"Hello, World!"
}
let app = Router::new().route(
"/",
get(handle.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(CompressionLayer::new())
.into_inner(),
)),
);
let client = TestClient::new(app);
let res = client.get("/").send().await;
let body = res.text().await;
assert_eq!(body, "Hello, World!");
}
#[crate::test]
async fn service_in_bottom() {
async fn handler(_req: Request) -> Result, Infallible> {
Ok(Response::new(hyper::Body::empty()))
}
let app = Router::new().route("/", get_service(service_fn(handler)));
TestClient::new(app);
}
/// For handler-based routes: a known path requested with an unregistered
/// method must give `405` with an `Allow` header, and an unknown path must
/// give `404`.
#[crate::test]
async fn wrong_method_handler() {
    let router = Router::new()
        .route("/", get(|| async {}).post(|| async {}))
        .route("/foo", patch(|| async {}));
    let http = TestClient::new(router);

    // Only GET and POST were registered on `/`; the expected `Allow` value
    // additionally lists HEAD.
    let root_patch = http.patch("/").send().await;
    assert_eq!(root_patch.status(), StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(root_patch.headers()[ALLOW], "GET,HEAD,POST");

    let foo_patch = http.patch("/foo").send().await;
    assert_eq!(foo_patch.status(), StatusCode::OK);

    let foo_post = http.post("/foo").send().await;
    assert_eq!(foo_post.status(), StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(foo_post.headers()[ALLOW], "PATCH");

    // `/bar` was never registered at all.
    let unknown = http.get("/bar").send().await;
    assert_eq!(unknown.status(), StatusCode::NOT_FOUND);
}
#[crate::test]
async fn wrong_method_service() {
#[derive(Clone)]
struct Svc;
impl Service for Svc {
type Response = Response>;
type Error = Infallible;
type Future = Ready>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _req: R) -> Self::Future {
ready(Ok(Response::new(Empty::new())))
}
}
let app = Router::new()
.route("/", get_service(Svc).post_service(Svc))
.route("/foo", patch_service(Svc));
let client = TestClient::new(app);
let res = client.patch("/").send().await;
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST");
let res = client.patch("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.post("/foo").send().await;
assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED);
assert_eq!(res.headers()[ALLOW], "PATCH");
let res = client.get("/bar").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
#[crate::test]
async fn multiple_methods_for_one_handler() {
async fn root(_: Request) -> &'static str {
"Hello, World!"
}
let app = Router::new().route("/", on(MethodFilter::GET | MethodFilter::POST, root));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.post("/").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
/// A handler behind a wildcard route must observe the full request URI, not
/// just the part captured by the wildcard.
#[crate::test]
async fn wildcard_sees_whole_url() {
    // Echo the URI the handler was given back as the response body.
    let echo_uri = |uri: Uri| async move { uri.to_string() };
    let router = Router::new().route("/api/*rest", get(echo_uri));
    let http = TestClient::new(router);

    let response = http.get("/api/foo/bar").send().await;
    assert_eq!(response.text().await, "/api/foo/bar");
}
/// A `.layer(...)` call must affect the routes added before it and leave
/// routes added afterwards alone.
#[crate::test]
async fn middleware_applies_to_routes_above() {
    let router = Router::new()
        // This handler never completes, so only the timeout can answer.
        .route("/one", get(std::future::pending::<()>))
        .layer(
            ServiceBuilder::new()
                .layer(HandleErrorLayer::new(|_: BoxError| async move {
                    StatusCode::REQUEST_TIMEOUT
                }))
                .layer(TimeoutLayer::new(Duration::ZERO)),
        )
        // Registered after the layer: must not be subject to the timeout.
        .route("/two", get(|| async {}));
    let http = TestClient::new(router);

    let layered = http.get("/one").send().await;
    assert_eq!(layered.status(), StatusCode::REQUEST_TIMEOUT);

    let unlayered = http.get("/two").send().await;
    assert_eq!(unlayered.status(), StatusCode::OK);
}
/// A route registered without a trailing slash must not match the same path
/// with one appended.
#[crate::test]
async fn not_found_for_extra_trailing_slash() {
    let http = TestClient::new(Router::new().route("/foo", get(|| async {})));

    let with_slash = http.get("/foo/").send().await;
    assert_eq!(with_slash.status(), StatusCode::NOT_FOUND);

    let exact = http.get("/foo").send().await;
    assert_eq!(exact.status(), StatusCode::OK);
}
/// A route registered with a trailing slash must not match the same path
/// without it.
#[crate::test]
async fn not_found_for_missing_trailing_slash() {
    let http = TestClient::new(Router::new().route("/foo/", get(|| async {})));

    let without_slash = http.get("/foo").send().await;
    assert_eq!(without_slash.status(), StatusCode::NOT_FOUND);
}
/// `/foo` and `/foo/` may be registered side by side, and each request must
/// reach the handler registered for its exact spelling.
#[crate::test]
async fn with_and_without_trailing_slash() {
    let router = Router::new()
        .route("/foo", get(|| async { "without tsr" }))
        .route("/foo/", get(|| async { "with tsr" }));
    let http = TestClient::new(router);

    let slashed = http.get("/foo/").send().await;
    assert_eq!(slashed.status(), StatusCode::OK);
    assert_eq!(slashed.text().await, "with tsr");

    let bare = http.get("/foo").send().await;
    assert_eq!(bare.status(), StatusCode::OK);
    assert_eq!(bare.text().await, "without tsr");
}
// for https://github.com/tokio-rs/axum/issues/420
/// A wildcard segment must capture at least something: neither the bare
/// prefix nor the prefix plus a lone `/` may match, while a non-empty tail
/// is captured in full (including inner slashes).
#[crate::test]
async fn wildcard_doesnt_match_just_trailing_slash() {
    let app = Router::new().route(
        "/x/*path",
        // Echo the captured wildcard value back as the response body.
        get(|Path(path): Path<String>| async move { path }),
    );

    let client = TestClient::new(app);

    let res = client.get("/x").send().await;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);

    let res = client.get("/x/").send().await;
    assert_eq!(res.status(), StatusCode::NOT_FOUND);

    let res = client.get("/x/foo/bar").send().await;
    assert_eq!(res.status(), StatusCode::OK);
    assert_eq!(res.text().await, "foo/bar");
}
/// Pins which of a root-level wildcard, a nested wildcard and the fallback
/// answers a range of paths, with and without trailing slashes.
#[crate::test]
async fn what_matches_wildcard() {
    let router = Router::new()
        .route("/*key", get(|| async { "root" }))
        .route("/x/*key", get(|| async { "x" }))
        .fallback(|| async { "fallback" });
    let http = TestClient::new(router);

    // (request path, body of the handler expected to answer it)
    let expectations = [
        ("/", "fallback"),
        ("/a", "root"),
        ("/a/", "root"),
        ("/a/b", "root"),
        ("/a/b/", "root"),
        ("/x", "root"),
        ("/x/", "root"),
        ("/x/a", "x"),
        ("/x/a/", "x"),
        ("/x/a/b", "x"),
        ("/x/a/b/", "x"),
    ];

    for (path, expected) in expectations {
        let in_flight = http.get(path).send();
        assert_eq!(in_flight.await.text().await, expected);
    }
}
/// A static route and a single-segment parameter route can coexist at the
/// same depth: `/foo` must reach the static handler while any other segment
/// reaches the dynamic one.
#[crate::test]
async fn static_and_dynamic_paths() {
    let app = Router::new()
        .route(
            "/:key",
            get(|Path(key): Path<String>| async move { format!("dynamic: {key}") }),
        )
        .route("/foo", get(|| async { "static" }));

    let client = TestClient::new(app);

    let res = client.get("/bar").send().await;
    assert_eq!(res.text().await, "dynamic: bar");

    let res = client.get("/foo").send().await;
    assert_eq!(res.text().await, "static");
}
/// Registering a route with an empty path must panic with the message
/// named in `should_panic`.
#[crate::test]
#[should_panic(expected = "Paths must start with a `/`. Use \"/\" for root routes")]
async fn empty_route() {
    let router = Router::new().route("", get(|| async {}));
    TestClient::new(router);
}
#[crate::test]
async fn middleware_still_run_for_unmatched_requests() {
#[derive(Clone)]
struct CountMiddleware(S);
static COUNT: AtomicUsize = AtomicUsize::new(0);
impl Service for CountMiddleware
where
S: Service,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: R) -> Self::Future {
COUNT.fetch_add(1, Ordering::SeqCst);
self.0.call(req)
}
}
let app = Router::new()
.route("/", get(|| async {}))
.layer(tower::layer::layer_fn(CountMiddleware));
let client = TestClient::new(app);
assert_eq!(COUNT.load(Ordering::SeqCst), 0);
client.get("/").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 1);
client.get("/not-found").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 2);
}
/// Passing a `Router` to `route_service` must panic with the message named
/// in `should_panic`.
#[crate::test]
#[should_panic(expected = "\
Invalid route: `Router::route_service` cannot be used with `Router`s. \
Use `Router::nest` instead\
")]
async fn routing_to_router_panics() {
    let outer = Router::new().route_service("/", Router::new());
    TestClient::new(outer);
}
/// `route_layer` middleware must guard matched routes only: a request to an
/// unknown path gets `404` rather than being rejected by the auth layer.
#[crate::test]
async fn route_layer() {
    let router = Router::new()
        .route("/foo", get(|| async {}))
        .route_layer(ValidateRequestHeaderLayer::bearer("password"));
    let http = TestClient::new(router);

    // Correct bearer token: the request reaches the handler.
    let authorized = http
        .get("/foo")
        .header("authorization", "Bearer password")
        .send()
        .await;
    assert_eq!(authorized.status(), StatusCode::OK);

    // No token on a matched route: rejected by the layer.
    let anonymous = http.get("/foo").send().await;
    assert_eq!(anonymous.status(), StatusCode::UNAUTHORIZED);

    // Unmatched path: the layer is not involved.
    let unknown = http.get("/not-found").send().await;
    assert_eq!(unknown.status(), StatusCode::NOT_FOUND);

    // It would be nice if this returned `405 Method Not Allowed`, but that
    // requires knowing more about which method route we're calling, which we
    // don't know currently since it's just a generic `Service`.
    let wrong_method = http.post("/foo").send().await;
    assert_eq!(wrong_method.status(), StatusCode::UNAUTHORIZED);
}
/// Two separate `.route` calls for the same path but different methods must
/// both remain reachable.
#[crate::test]
async fn different_methods_added_in_different_routes() {
    let router = Router::new()
        .route("/", get(|| async { "GET" }))
        .route("/", post(|| async { "POST" }));
    let http = TestClient::new(router);

    let get_response = http.get("/").send().await;
    assert_eq!(get_response.text().await, "GET");

    let post_response = http.post("/").send().await;
    assert_eq!(post_response.text().await, "POST");
}
/// Merging two routers that each define a fallback must panic with the
/// message named in `should_panic`.
#[crate::test]
#[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")]
async fn merging_routers_with_fallbacks_panics() {
    async fn fallback() {}

    let left = Router::new().fallback(fallback);
    let right = Router::new().fallback(fallback);

    TestClient::new(left.merge(right));
}
/// Registering the same method twice for the same path must panic with the
/// message named in `should_panic`.
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn routes_with_overlapping_method_routes() {
    async fn handler() {}

    let first: Router = Router::new().route("/foo/bar", get(handler));
    // Second registration of `GET /foo/bar` is the overlapping one.
    let _: Router = first.route("/foo/bar", get(handler));
}
/// Merging a router with a clone of itself (so both sides define
/// `GET /foo/bar`) must panic with the message named in `should_panic`.
#[test]
#[should_panic(expected = "Overlapping method route. Handler for `GET /foo/bar` already exists")]
fn merging_with_overlapping_method_routes() {
    async fn handler() {}

    let router: Router = Router::new().route("/foo/bar", get(handler));
    let duplicate = router.clone();
    let _ = duplicate.merge(router);
}
/// Two routers that define the same path for different methods can be
/// merged, and both methods stay reachable afterwards.
#[crate::test]
async fn merging_routers_with_same_paths_but_different_methods() {
    let get_only = Router::new().route("/", get(|| async { "GET" }));
    let post_only = Router::new().route("/", post(|| async { "POST" }));
    let http = TestClient::new(get_only.merge(post_only));

    let get_response = http.get("/").send().await;
    assert_eq!(get_response.text().await, "GET");

    let post_response = http.post("/").send().await;
    assert_eq!(post_response.text().await, "POST");
}
/// `HEAD` requests to `GET` routes must report the `content-length` of the
/// body the handler would produce, while the body itself arrives empty.
#[crate::test]
async fn head_content_length_through_hyper_server() {
    let router = Router::new()
        .route("/", get(|| async { "foo" }))
        .route("/json", get(|| async { Json(json!({ "foo": 1 })) }));
    let http = TestClient::new(router);

    // Plain-text handler: the expected length is that of "foo".
    let text_head = http.head("/").send().await;
    assert_eq!(text_head.headers()["content-length"], "3");
    assert!(text_head.text().await.is_empty());

    // JSON handler: the test pins a length of 9 bytes.
    let json_head = http.head("/json").send().await;
    assert_eq!(json_head.headers()["content-length"], "9");
    assert!(json_head.text().await.is_empty());
}
/// A `HEAD` request answered by the fallback must also carry the
/// `content-length` of the fallback's body.
#[crate::test]
async fn head_content_length_through_hyper_server_that_hits_fallback() {
    let http = TestClient::new(Router::new().fallback(|| async { "foo" }));

    let head = http.head("/").send().await;
    assert_eq!(head.headers()["content-length"], "3");
}
/// Checks the framing headers of `GET` and `HEAD` responses when the router
/// sits behind a compression layer configured with `SizeAbove::new(0)`.
#[crate::test]
async fn head_with_middleware_applied() {
    use tower_http::compression::{predicate::SizeAbove, CompressionLayer};

    let router = Router::new()
        .nest(
            "/",
            Router::new().route("/", get(|| async { "Hello, World!" })),
        )
        .layer(CompressionLayer::new().compress_when(SizeAbove::new(0)));
    let http = TestClient::new(router);

    // GET: the response is expected to be chunked ...
    let get_response = http
        .get("/")
        .header("accept-encoding", "gzip")
        .send()
        .await;
    assert_eq!(get_response.headers()["transfer-encoding"], "chunked");
    // ... and `transfer-encoding: chunked` cannot be combined with
    // `content-length`.
    assert!(!get_response.headers().contains_key("content-length"));

    // HEAD: there is no response body, so no `transfer-encoding` ...
    let head_response = http
        .head("/")
        .header("accept-encoding", "gzip")
        .send()
        .await;
    assert!(!head_response.headers().contains_key("transfer-encoding"));
    // ... and no `content-length` either, since the length of the compressed
    // response cannot be known.
    assert!(!head_response.headers().contains_key("content-length"));
}
/// Registering a route whose path lacks a leading `/` must panic with the
/// message named in `should_panic`.
#[crate::test]
#[should_panic(expected = "Paths must start with a `/`")]
async fn routes_must_start_with_slash() {
    let router = Router::new().route(":foo", get(|| async {}));
    TestClient::new(router);
}
#[crate::test]
async fn body_limited_by_default() {
let app = Router::new()
.route("/bytes", post(|_: Bytes| async {}))
.route("/string", post(|_: String| async {}))
.route("/json", post(|_: Json| async {}));
let client = TestClient::new(app);
for uri in ["/bytes", "/string", "/json"] {
println!("calling {uri}");
let stream = futures_util::stream::repeat("a".repeat(1000)).map(Ok::<_, hyper::Error>);
let body = Body::wrap_stream(stream);
let res_future = client
.post(uri)
.header("content-type", "application/json")
.body(body)
.send();
let res = tokio::time::timeout(Duration::from_secs(3), res_future)
.await
.expect("never got response");
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
}
/// With `DefaultBodyLimit::disable()` applied, a body larger than the
/// default limit must be accepted.
#[crate::test]
async fn disabling_the_default_limit() {
    let router = Router::new()
        .route("/", post(|_: Bytes| async {}))
        .layer(DefaultBodyLimit::disable());
    let http = TestClient::new(router);

    // `DEFAULT_LIMIT` is 2mb so make a body larger than that
    let oversized = Body::from("a".repeat(3_000_000));

    let response = http.post("/").body(oversized).send().await;
    assert_eq!(response.status(), StatusCode::OK);
}
/// `RequestBodyLimitLayer` applied to requests that carry a
/// `content-length` header: a body of exactly `LIMIT` bytes passes, one of
/// twice that size is rejected with `413`.
#[crate::test]
async fn limited_body_with_content_length() {
    const LIMIT: usize = 3;

    // The handler itself verifies that the request it sees carries a
    // `content-length` header.
    let handler = post(|headers: HeaderMap, _body: Bytes| async move {
        assert!(headers.get(CONTENT_LENGTH).is_some());
    });
    let router = Router::new()
        .route("/", handler)
        .layer(RequestBodyLimitLayer::new(LIMIT));
    let http = TestClient::new(router);

    let at_limit = http.post("/").body("a".repeat(LIMIT)).send().await;
    assert_eq!(at_limit.status(), StatusCode::OK);

    let over_limit = http.post("/").body("a".repeat(LIMIT * 2)).send().await;
    assert_eq!(over_limit.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
/// `DefaultBodyLimit::max(n)` must accept a body of exactly `n` bytes and
/// reject one of `n + 1` bytes with `413`.
#[crate::test]
async fn changing_the_default_limit() {
    let new_limit = 2;

    let router = Router::new()
        .route("/", post(|_: Bytes| async {}))
        .layer(DefaultBodyLimit::max(new_limit));
    let http = TestClient::new(router);

    let at_limit = Body::from("a".repeat(new_limit));
    let accepted = http.post("/").body(at_limit).send().await;
    assert_eq!(accepted.status(), StatusCode::OK);

    let one_over = Body::from("a".repeat(new_limit + 1));
    let rejected = http.post("/").body(one_over).send().await;
    assert_eq!(rejected.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
#[crate::test]
async fn limited_body_with_streaming_body() {
const LIMIT: usize = 3;
let app = Router::new()
.route(
"/",
post(|headers: HeaderMap, _body: Bytes| async move {
assert!(headers.get(CONTENT_LENGTH).is_none());
}),
)
.layer(RequestBodyLimitLayer::new(LIMIT));
let client = TestClient::new(app);
let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT))]);
let res = client
.post("/")
.body(Body::wrap_stream(stream))
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let stream = futures_util::stream::iter(vec![Ok::<_, hyper::Error>("a".repeat(LIMIT * 2))]);
let res = client
.post("/")
.body(Body::wrap_stream(stream))
.send()
.await;
assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE);
}
/// A handler can extract both the full application state and a sub-state
/// derived from it through `FromRef`.
#[crate::test]
async fn extract_state() {
    #[derive(Clone)]
    struct AppState {
        value: i32,
        inner: InnerState,
    }

    #[derive(Clone)]
    struct InnerState {
        value: i32,
    }

    // Lets `State<InnerState>` be extracted from a router whose state is
    // `AppState`.
    impl FromRef<AppState> for InnerState {
        fn from_ref(state: &AppState) -> Self {
            state.inner.clone()
        }
    }

    async fn handler(State(outer): State<AppState>, State(inner): State<InnerState>) {
        assert_eq!(outer.value, 1);
        assert_eq!(inner.value, 2);
    }

    let state = AppState {
        value: 1,
        inner: InnerState { value: 2 },
    };

    let app = Router::new().route("/", get(handler)).with_state(state);
    let client = TestClient::new(app);

    let res = client.get("/").send().await;
    assert_eq!(res.status(), StatusCode::OK);
}
/// State given directly to a method router via `with_state` must be the one
/// its handler sees, even when the enclosing router later receives a
/// different state.
#[crate::test]
async fn explicitly_set_state() {
    let router = Router::new()
        .route_service(
            "/",
            // This method router gets its own state, "foo" ...
            get(|State(state): State<&'static str>| async move { state }).with_state("foo"),
        )
        // ... which differs from the state given to the outer router.
        .with_state("...");
    let http = TestClient::new(router);

    let response = http.get("/").send().await;
    assert_eq!(response.text().await, "foo");
}
#[crate::test]
async fn layer_response_into_response() {
fn map_response(_res: Response) -> Result, impl IntoResponse> {
let headers = [("x-foo", "bar")];
let status = StatusCode::IM_A_TEAPOT;
Err((headers, status))
}
let app = Router::new()
.route("/", get(|| async {}))
.layer(MapResponseLayer::new(map_response));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.headers()["x-foo"], "bar");
assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
}
/// Type-check-only: a method router whose handlers and fallback extract
/// `State` can serve as a router's fallback before the state is supplied.
// Never called and not a test, so `dead_code` is silenced; the function only
// has to compile.
#[allow(dead_code)]
fn method_router_fallback_with_state() {
    async fn fallback(_: State<&'static str>) {}

    async fn not_found(_: State<&'static str>) {}

    let shared = "foo";

    let _: Router = Router::new()
        .fallback(get(fallback).fallback(not_found))
        .with_state(shared);
}
/// Pins the output of `path_for_nested_route` for every combination of
/// root, slash-terminated and plain inputs.
#[test]
fn test_path_for_nested_route() {
    // (first argument, second argument, expected result)
    let cases = [
        ("/", "/", "/"),
        ("/a", "/", "/a"),
        ("/", "/b", "/b"),
        ("/a/", "/", "/a/"),
        ("/", "/b/", "/b/"),
        ("/a", "/b", "/a/b"),
        ("/a/", "/b", "/a/b"),
        ("/a", "/b/", "/a/b/"),
        ("/a/", "/b/", "/a/b/"),
    ];

    for (first, second, expected) in cases {
        assert_eq!(path_for_nested_route(first, second), expected);
    }
}
#[crate::test]
async fn state_isnt_cloned_too_much() {
static SETUP_DONE: AtomicBool = AtomicBool::new(false);
static COUNT: AtomicUsize = AtomicUsize::new(0);
struct AppState;
impl Clone for AppState {
fn clone(&self) -> Self {
#[rustversion::since(1.65)]
#[track_caller]
fn count() {
if SETUP_DONE.load(Ordering::SeqCst) {
let bt = std::backtrace::Backtrace::force_capture();
let bt = bt
.to_string()
.lines()
.filter(|line| line.contains("axum") || line.contains("./src"))
.collect::>()
.join("\n");
println!("AppState::Clone:\n===============\n{}\n", bt);
COUNT.fetch_add(1, Ordering::SeqCst);
}
}
#[rustversion::not(since(1.65))]
fn count() {
if SETUP_DONE.load(Ordering::SeqCst) {
COUNT.fetch_add(1, Ordering::SeqCst);
}
}
count();
Self
}
}
let app = Router::new()
.route("/", get(|_: State| async {}))
.with_state(AppState);
let client = TestClient::new(app);
// ignore clones made during setup
SETUP_DONE.store(true, Ordering::SeqCst);
client.get("/").send().await;
assert_eq!(COUNT.load(Ordering::SeqCst), 4);
}
#[crate::test]
async fn logging_rejections() {
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(deny_unknown_fields)]
struct RejectionEvent {
message: String,
status: u16,
body: String,
rejection_type: String,
}
let events = capture_tracing::(|| async {
let app = Router::new()
.route("/extension", get(|_: Extension| async {}))
.route("/string", post(|_: String| async {}));
let client = TestClient::new(app);
assert_eq!(
client.get("/extension").send().await.status(),
StatusCode::INTERNAL_SERVER_ERROR
);
assert_eq!(
client
.post("/string")
.body(Vec::from([0, 159, 146, 150]))
.send()
.await
.status(),
StatusCode::BAD_REQUEST,
);
})
.await;
assert_eq!(
dbg!(events),
Vec::from([
TracingEvent {
fields: RejectionEvent {
message: "rejecting request".to_owned(),
status: 500,
body: "Missing request extension: Extension of \
type `core::convert::Infallible` was not found. \
Perhaps you forgot to add it? See `axum::Extension`."
.to_owned(),
rejection_type: "axum::extract::rejection::MissingExtension".to_owned(),
},
target: "axum::rejection".to_owned(),
level: "TRACE".to_owned(),
},
TracingEvent {
fields: RejectionEvent {
message: "rejecting request".to_owned(),
status: 400,
body: "Request body didn't contain valid UTF-8: \
invalid utf-8 sequence of 1 bytes from index 1"
.to_owned(),
rejection_type: "axum_core::extract::rejection::InvalidUtf8".to_owned(),
},
target: "axum::rejection".to_owned(),
level: "TRACE".to_owned(),
},
])
)
}
// https://github.com/tokio-rs/axum/issues/1955
/// A `CONNECT` request (authority-form URI) sent to a router with only a
/// custom fallback must be answered by that fallback.
#[crate::test]
async fn connect_going_to_custom_fallback() {
    let router = Router::new().fallback(|| async { (StatusCode::NOT_FOUND, "custom fallback") });

    let connect = Request::builder()
        .uri("example.com:443")
        .method(Method::CONNECT)
        .header(HOST, "example.com:443")
        .body(Body::empty())
        .unwrap();

    // Drive the router directly as a service instead of through the test
    // client.
    let response = router.oneshot(connect).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let raw = hyper::body::to_bytes(response).await.unwrap();
    let text = String::from_utf8(raw.to_vec()).unwrap();
    assert_eq!(text, "custom fallback");
}
// https://github.com/tokio-rs/axum/issues/1955
/// A `CONNECT` request (authority-form URI) sent to an empty router must
/// get `404` with an empty body.
#[crate::test]
async fn connect_going_to_default_fallback() {
    let router = Router::new();

    let connect = Request::builder()
        .uri("example.com:443")
        .method(Method::CONNECT)
        .header(HOST, "example.com:443")
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(connect).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let payload = hyper::body::to_bytes(response).await.unwrap();
    assert!(payload.is_empty());
}
/// A plain `IntoResponse` value (here a status/body tuple) can be used
/// directly where a handler is expected.
#[crate::test]
async fn impl_handler_for_into_response() {
    let canned = (StatusCode::CREATED, "thing created");
    let router = Router::new().route("/things", post(canned));
    let http = TestClient::new(router);

    let response = http.post("/things").send().await;
    assert_eq!(response.status(), StatusCode::CREATED);
    assert_eq!(response.text().await, "thing created");
}