1 //! Extractor for getting connection information from a client.
2 //!
3 //! See [`Router::into_make_service_with_connect_info`] for more details.
4 //!
5 //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
6 
7 use super::{Extension, FromRequestParts};
8 use crate::middleware::AddExtension;
9 use async_trait::async_trait;
10 use http::request::Parts;
11 use hyper::server::conn::AddrStream;
12 use std::{
13     convert::Infallible,
14     fmt,
15     future::ready,
16     marker::PhantomData,
17     net::SocketAddr,
18     task::{Context, Poll},
19 };
20 use tower_layer::Layer;
21 use tower_service::Service;
22 
/// A [`MakeService`] created from a router.
///
/// For every connection target `T` (where `C: Connected<T>`), calling this
/// service derives a `C` via [`Connected::connect_info`] and yields a clone of
/// the wrapped service layered with an [`Extension`] holding
/// [`ConnectInfo<C>`], which is where the [`ConnectInfo`] extractor looks it up.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    // The service that is cloned for each new connection.
    svc: S,
    // `C` is only a type-level marker: no `C` value is stored. Using
    // `fn() -> C` lets the `traits` test require `Send` for this struct even
    // when `C` itself is not `Send`/`Sync`.
    _connect_info: PhantomData<fn() -> C>,
}
33 
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    // `C` is only carried at the type level, so the make-service must remain
    // `Send` even when the connect-info type is neither `Send` nor `Sync`.
    assert_send::<IntoMakeServiceWithConnectInfo<(), NotSendSync>>();
}
39 
40 impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
new(svc: S) -> Self41     pub(crate) fn new(svc: S) -> Self {
42         Self {
43             svc,
44             _connect_info: PhantomData,
45         }
46     }
47 }
48 
49 impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
50 where
51     S: fmt::Debug,
52 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result53     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54         f.debug_struct("IntoMakeServiceWithConnectInfo")
55             .field("svc", &self.svc)
56             .finish()
57     }
58 }
59 
60 impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
61 where
62     S: Clone,
63 {
clone(&self) -> Self64     fn clone(&self) -> Self {
65         Self {
66             svc: self.svc.clone(),
67             _connect_info: PhantomData,
68         }
69     }
70 }
71 
/// Produces connection metadata from a connected IO resource.
///
/// Implementing this for your own types lets custom IO resources expose the
/// same kind of per-connection information as the built-in ones.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Send + Sync + Clone + 'static {
    /// Builds the connection information from `target`.
    fn connect_info(target: T) -> Self;
}
85 
86 impl Connected<&AddrStream> for SocketAddr {
connect_info(target: &AddrStream) -> Self87     fn connect_info(target: &AddrStream) -> Self {
88         target.remote_addr()
89     }
90 }
91 
92 impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
93 where
94     S: Clone,
95     C: Connected<T>,
96 {
97     type Response = AddExtension<S, ConnectInfo<C>>;
98     type Error = Infallible;
99     type Future = ResponseFuture<S, C>;
100 
101     #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>102     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103         Poll::Ready(Ok(()))
104     }
105 
call(&mut self, target: T) -> Self::Future106     fn call(&mut self, target: T) -> Self::Future {
107         let connect_info = ConnectInfo(C::connect_info(target));
108         let svc = Extension(connect_info).layer(self.svc.clone());
109         ResponseFuture::new(ready(Ok(svc)))
110     }
111 }
112 
opaque_future! {
    /// Response future for [`IntoMakeServiceWithConnectInfo`].
    ///
    /// Wraps a `std::future::Ready`, so it resolves on first poll to the
    /// per-connection service built in `call`.
    pub type ResponseFuture<S, C> =
        std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
}
118 
/// Extractor that yields the connection information produced by a
/// [`Connected`] implementation.
///
/// The app must be run with [`Router::into_make_service_with_connect_info`],
/// or have a [`MockConnectInfo`] layer with the same `T` applied; otherwise
/// extraction fails at runtime.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct ConnectInfo<T>(pub T);
130 
131 #[async_trait]
132 impl<S, T> FromRequestParts<S> for ConnectInfo<T>
133 where
134     S: Send + Sync,
135     T: Clone + Send + Sync + 'static,
136 {
137     type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
138 
from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection>139     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
140         match Extension::<Self>::from_request_parts(parts, state).await {
141             Ok(Extension(connect_info)) => Ok(connect_info),
142             Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
143                 Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
144                 None => Err(err),
145             },
146         }
147     }
148 }
149 
// NOTE(review): presumably implements `Deref` (and likely `DerefMut`) from
// `ConnectInfo<T>` to the wrapped `T`; confirm against `axum_core`'s macro.
axum_core::__impl_deref!(ConnectInfo);
151 
/// Layer that supplies a fixed [`ConnectInfo`] value, intended for tests.
///
/// When a request carries no real connection info, a `ConnectInfo<T>`
/// extractor falls back to the `MockConnectInfo<T>` (same `T`) installed by
/// this layer. If the app is also served through
/// [`Router::into_make_service_with_connect_info`], the real connection info
/// takes precedence and the mock is ignored.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::connect_info::{MockConnectInfo, ConnectInfo},
///     body::Body,
///     routing::get,
///     http::{Request, StatusCode},
/// };
/// use std::net::SocketAddr;
/// use tower::ServiceExt;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
///
/// // this router you can run with `app.into_make_service_with_connect_info::<SocketAddr>()`
/// fn app() -> Router {
///     Router::new().route("/", get(handler))
/// }
///
/// // use this router for tests
/// fn test_app() -> Router {
///     app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
/// }
///
/// // #[tokio::test]
/// async fn some_test() {
///     let app = test_app();
///
///     let request = Request::new(Body::empty());
///     let response = app.oneshot(request).await.unwrap();
///     assert_eq!(response.status(), StatusCode::OK);
/// }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// #     some_test().await;
/// # }
/// ```
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct MockConnectInfo<T>(pub T);
201 
202 impl<S, T> Layer<S> for MockConnectInfo<T>
203 where
204     T: Clone + Send + Sync + 'static,
205 {
206     type Service = <Extension<Self> as Layer<S>>::Service;
207 
layer(&self, inner: S) -> Self::Service208     fn layer(&self, inner: S) -> Self::Service {
209         Extension(self.clone()).layer(inner)
210     }
211 }
212 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router, Server};
    use std::net::{SocketAddr, TcpListener};

    /// A real server exposes the client's loopback address via
    /// `ConnectInfo<SocketAddr>`.
    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }

    /// A user-defined `Connected` implementation is passed through to handlers.
    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<&AddrStream> for MyConnectInfo {
            fn connect_info(_target: &AddrStream) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str {
            addr.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<MyConnectInfo>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert_eq!(body, "it worked!");
    }

    /// Without a real server, `MockConnectInfo` supplies the extracted value.
    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        let body = res.text().await;
        // The mocked address is fully determined, so compare it exactly.
        assert_eq!(body, "0.0.0.0:1337");
    }

    /// When both are present, the real connection info beats the mock.
    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Same readiness handshake as the other real-server tests.
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(handler))
                .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }
}
330