1 //! Extractor for getting connection information from a client.
2 //!
3 //! See [`Router::into_make_service_with_connect_info`] for more details.
4 //!
5 //! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
6
7 use super::{Extension, FromRequestParts};
8 use crate::middleware::AddExtension;
9 use async_trait::async_trait;
10 use http::request::Parts;
11 use hyper::server::conn::AddrStream;
12 use std::{
13 convert::Infallible,
14 fmt,
15 future::ready,
16 marker::PhantomData,
17 net::SocketAddr,
18 task::{Context, Poll},
19 };
20 use tower_layer::Layer;
21 use tower_service::Service;
22
/// A [`MakeService`] built from a router that attaches connection information
/// to every service it hands out.
///
/// Values of this type are produced by
/// [`Router::into_make_service_with_connect_info`]; see that method for more
/// details.
///
/// The connection-info type `C` is tracked only at the type level, through a
/// `PhantomData<fn() -> C>` marker. No `C` value is stored, so whether this
/// wrapper is `Send`/`Sync` depends on `S` alone, not on `C`.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    /// The inner service; a clone of it is handed out on every `call`.
    svc: S,
    _connect_info: PhantomData<fn() -> C>,
}
33
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    // The make-service must remain `Send` even when the connection-info type
    // is `NotSendSync` (a test helper type that, going by its name, is
    // neither `Send` nor `Sync`).
    assert_send::<IntoMakeServiceWithConnectInfo<(), NotSendSync>>();
}
39
40 impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
new(svc: S) -> Self41 pub(crate) fn new(svc: S) -> Self {
42 Self {
43 svc,
44 _connect_info: PhantomData,
45 }
46 }
47 }
48
49 impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
50 where
51 S: fmt::Debug,
52 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("IntoMakeServiceWithConnectInfo")
55 .field("svc", &self.svc)
56 .finish()
57 }
58 }
59
60 impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
61 where
62 S: Clone,
63 {
clone(&self) -> Self64 fn clone(&self) -> Self {
65 Self {
66 svc: self.svc.clone(),
67 _connect_info: PhantomData,
68 }
69 }
70 }
71
/// Builds connection metadata from a connected IO resource.
///
/// The type implementing this trait is the *connection information* itself
/// (for example [`SocketAddr`]). The type parameter `T` is the IO resource
/// that information is derived from (for example `&AddrStream`). Implementing
/// the trait for your own types lets custom IO types supply the same kind of
/// connection metadata.
///
/// The `Clone + Send + Sync + 'static` requirements match the bounds that the
/// [`ConnectInfo`] extractor places on the value it returns.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>
where
    Self: Clone + Send + Sync + 'static,
{
    /// Creates the connection information from `target`, the value passed to
    /// [`IntoMakeServiceWithConnectInfo`]'s `Service::call`.
    fn connect_info(target: T) -> Self;
}
85
86 impl Connected<&AddrStream> for SocketAddr {
connect_info(target: &AddrStream) -> Self87 fn connect_info(target: &AddrStream) -> Self {
88 target.remote_addr()
89 }
90 }
91
92 impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
93 where
94 S: Clone,
95 C: Connected<T>,
96 {
97 type Response = AddExtension<S, ConnectInfo<C>>;
98 type Error = Infallible;
99 type Future = ResponseFuture<S, C>;
100
101 #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>102 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103 Poll::Ready(Ok(()))
104 }
105
call(&mut self, target: T) -> Self::Future106 fn call(&mut self, target: T) -> Self::Future {
107 let connect_info = ConnectInfo(C::connect_info(target));
108 let svc = Extension(connect_info).layer(self.svc.clone());
109 ResponseFuture::new(ready(Ok(svc)))
110 }
111 }
112
// Declares the public `ResponseFuture` returned by
// `IntoMakeServiceWithConnectInfo::call`, using the crate's `opaque_future!`
// macro (presumably so the concrete `std::future::Ready` type stays out of the
// public API; confirm against the macro definition). The inner future is always
// immediately ready because `call` builds the service synchronously.
opaque_future! {
    /// Response future for [`IntoMakeServiceWithConnectInfo`].
    pub type ResponseFuture<S, C> =
        std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
}
118
/// Extractor that yields the connection information produced by a
/// [`Connected`] implementation.
///
/// The app must be served with
/// [`Router::into_make_service_with_connect_info`], or wrapped in a
/// [`MockConnectInfo`] layer in tests. Otherwise extraction fails at runtime,
/// not at compile time.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct ConnectInfo<T>(pub T);
130
131 #[async_trait]
132 impl<S, T> FromRequestParts<S> for ConnectInfo<T>
133 where
134 S: Send + Sync,
135 T: Clone + Send + Sync + 'static,
136 {
137 type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
138
from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection>139 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
140 match Extension::<Self>::from_request_parts(parts, state).await {
141 Ok(Extension(connect_info)) => Ok(connect_info),
142 Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
143 Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
144 None => Err(err),
145 },
146 }
147 }
148 }
149
// Generates pointer-like access to the wrapped value for `ConnectInfo` via a
// helper macro from `axum_core` (presumably `Deref`/`DerefMut` targeting `T`;
// confirm against the macro definition).
axum_core::__impl_deref!(ConnectInfo);
151
/// Layer that supplies a fixed [`ConnectInfo`] value, intended for tests.
///
/// When an app is driven without
/// [`Router::into_make_service_with_connect_info`] (for example directly via
/// `tower::ServiceExt::oneshot`), no real connection information exists, and the
/// [`ConnectInfo`] extractor would reject the request. With this layer applied,
/// the extractor falls back to the wrapped value instead.
///
/// If an app accidentally uses both [`MockConnectInfo`] and
/// [`Router::into_make_service_with_connect_info`], the real connection
/// information from [`Router::into_make_service_with_connect_info`] wins.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::connect_info::{MockConnectInfo, ConnectInfo},
///     body::Body,
///     routing::get,
///     http::{Request, StatusCode},
/// };
/// use std::net::SocketAddr;
/// use tower::ServiceExt;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
///
/// // this router you can run with `app.into_make_service_with_connect_info::<SocketAddr>()`
/// fn app() -> Router {
///     Router::new().route("/", get(handler))
/// }
///
/// // use this router for tests
/// fn test_app() -> Router {
///     app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
/// }
///
/// // #[tokio::test]
/// async fn some_test() {
///     let app = test_app();
///
///     let request = Request::new(Body::empty());
///     let response = app.oneshot(request).await.unwrap();
///     assert_eq!(response.status(), StatusCode::OK);
/// }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// #     some_test().await;
/// # }
/// ```
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct MockConnectInfo<T>(pub T);
201
202 impl<S, T> Layer<S> for MockConnectInfo<T>
203 where
204 T: Clone + Send + Sync + 'static,
205 {
206 type Service = <Extension<Self> as Layer<S>>::Service;
207
layer(&self, inner: S) -> Self::Service208 fn layer(&self, inner: S) -> Self::Service {
209 Extension(self.clone()).layer(inner)
210 }
211 }
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router, Server};
    use std::net::{SocketAddr, TcpListener};

    /// The built-in `SocketAddr` connect info reports the client's address.
    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        // The client's port is ephemeral, so only the host part is fixed.
        assert!(body.starts_with("127.0.0.1:"));
    }

    /// A user-defined `Connected` type is produced and extracted.
    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<&AddrStream> for MyConnectInfo {
            fn connect_info(_target: &AddrStream) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(info): ConnectInfo<MyConnectInfo>) -> &'static str {
            info.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<MyConnectInfo>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert_eq!(body, "it worked!");
    }

    /// Without a real connect-info make-service, the mocked value is extracted.
    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        let body = res.text().await;
        // The mocked address is fully known, so compare exactly; a prefix
        // check would also accept e.g. "0.0.0.0:13370".
        assert_eq!(body, "0.0.0.0:1337");
    }

    /// When both are present, the real connection info takes precedence.
    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Same readiness handshake as the other server-backed tests.
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(handler))
                .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

            let server = Server::from_tcp(listener)
                .unwrap()
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            tx.send(()).unwrap();
            server.await.expect("server error");
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }
}
330