//! Extractor for getting connection information from a client. //! //! See [`Router::into_make_service_with_connect_info`] for more details. //! //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info use super::{Extension, FromRequestParts}; use crate::middleware::AddExtension; use async_trait::async_trait; use http::request::Parts; use hyper::server::conn::AddrStream; use std::{ convert::Infallible, fmt, future::ready, marker::PhantomData, net::SocketAddr, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// A [`MakeService`] created from a router. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub struct IntoMakeServiceWithConnectInfo { svc: S, _connect_info: PhantomData C>, } #[test] fn traits() { use crate::test_helpers::*; assert_send::>(); } impl IntoMakeServiceWithConnectInfo { pub(crate) fn new(svc: S) -> Self { Self { svc, _connect_info: PhantomData, } } } impl fmt::Debug for IntoMakeServiceWithConnectInfo where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoMakeServiceWithConnectInfo") .field("svc", &self.svc) .finish() } } impl Clone for IntoMakeServiceWithConnectInfo where S: Clone, { fn clone(&self) -> Self { Self { svc: self.svc.clone(), _connect_info: PhantomData, } } } /// Trait that connected IO resources implement and use to produce information /// about the connection. /// /// The goal for this trait is to allow users to implement custom IO types that /// can still provide the same connection metadata. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. fn connect_info(target: T) -> Self; } impl Connected<&AddrStream> for SocketAddr { fn connect_info(target: &AddrStream) -> Self { target.remote_addr() } } impl Service for IntoMakeServiceWithConnectInfo where S: Clone, C: Connected, { type Response = AddExtension>; type Error = Infallible; type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, target: T) -> Self::Future { let connect_info = ConnectInfo(C::connect_info(target)); let svc = Extension(connect_info).layer(self.svc.clone()); ResponseFuture::new(ready(Ok(svc))) } } opaque_future! { /// Response future for [`IntoMakeServiceWithConnectInfo`]. pub type ResponseFuture = std::future::Ready>, Infallible>>; } /// Extractor for getting connection information produced by a [`Connected`]. /// /// Note this extractor requires you to use /// [`Router::into_make_service_with_connect_info`] to run your app /// otherwise it will fail at runtime. /// /// See [`Router::into_make_service_with_connect_info`] for more details. /// /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info #[derive(Clone, Copy, Debug)] pub struct ConnectInfo(pub T); #[async_trait] impl FromRequestParts for ConnectInfo where S: Send + Sync, T: Clone + Send + Sync + 'static, { type Rejection = as FromRequestParts>::Rejection; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { match Extension::::from_request_parts(parts, state).await { Ok(Extension(connect_info)) => Ok(connect_info), Err(err) => match parts.extensions.get::>() { Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())), None => Err(err), }, } } } axum_core::__impl_deref!(ConnectInfo); /// Middleware used to mock [`ConnectInfo`] during tests. /// /// If you're accidentally using [`MockConnectInfo`] and /// [`Router::into_make_service_with_connect_info`] at the same time then /// [`Router::into_make_service_with_connect_info`] takes precedence. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// extract::connect_info::{MockConnectInfo, ConnectInfo}, /// body::Body, /// routing::get, /// http::{Request, StatusCode}, /// }; /// use std::net::SocketAddr; /// use tower::ServiceExt; /// /// async fn handler(ConnectInfo(addr): ConnectInfo) {} /// /// // this router you can run with `app.into_make_service_with_connect_info::()` /// fn app() -> Router { /// Router::new().route("/", get(handler)) /// } /// /// // use this router for tests /// fn test_app() -> Router { /// app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))) /// } /// /// // #[tokio::test] /// async fn some_test() { /// let app = test_app(); /// /// let request = Request::new(Body::empty()); /// let response = app.oneshot(request).await.unwrap(); /// assert_eq!(response.status(), StatusCode::OK); /// } /// # /// # #[tokio::main] /// # async fn main() { /// # some_test().await; /// # } /// ``` /// /// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info #[derive(Clone, Copy, Debug)] pub struct MockConnectInfo(pub T); impl Layer for MockConnectInfo where T: Clone + Send + Sync + 'static, { type Service = as Layer>::Service; fn layer(&self, inner: S) -> Self::Service { Extension(self.clone()).layer(inner) } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router, Server}; use std::net::{SocketAddr, TcpListener}; #[crate::test] async fn socket_addr() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let (tx, rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { let app = Router::new().route("/", get(handler)); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); tx.send(()).unwrap(); server.await.expect("server error"); }); rx.await.unwrap(); let client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert!(body.starts_with("127.0.0.1:")); } #[crate::test] async fn custom() { #[derive(Clone, Debug)] struct MyConnectInfo { value: &'static str, } impl Connected<&AddrStream> for MyConnectInfo { fn connect_info(_target: &AddrStream) -> Self { Self { value: "it worked!", } } } async fn handler(ConnectInfo(addr): ConnectInfo) -> &'static str { addr.value } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let (tx, rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { let app = Router::new().route("/", get(handler)); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); tx.send(()).unwrap(); server.await.expect("server error"); }); rx.await.unwrap(); let client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert_eq!(body, "it worked!"); } #[crate::test] async fn mock_connect_info() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let app = Router::new() .route("/", get(handler)) .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); let client = TestClient::new(app); let res = client.get("/").send().await; let body = res.text().await; assert!(body.starts_with("0.0.0.0:1337")); } #[crate::test] async fn both_mock_and_real_connect_info() { async fn handler(ConnectInfo(addr): ConnectInfo) -> String { format!("{addr}") } let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); tokio::spawn(async move { let app = Router::new() .route("/", get(handler)) .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)))); let server = Server::from_tcp(listener) .unwrap() .serve(app.into_make_service_with_connect_info::()); server.await.expect("server error"); }); let client = reqwest::Client::new(); let res = client.get(format!("http://{addr}")).send().await.unwrap(); let body = res.text().await.unwrap(); assert!(body.starts_with("127.0.0.1:")); } }