1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::{FromRequest, FromRequestParts};
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6     any::type_name,
7     convert::Infallible,
8     fmt,
9     future::Future,
10     marker::PhantomData,
11     pin::Pin,
12     task::{Context, Poll},
13 };
14 use tower::{util::BoxCloneService, ServiceBuilder};
15 use tower_layer::Layer;
16 use tower_service::Service;
17 
18 /// Create a middleware from an async function.
19 ///
20 /// `from_fn` requires the function given to
21 ///
22 /// 1. Be an `async fn`.
23 /// 2. Take one or more [extractors] as the first arguments.
24 /// 3. Take [`Next<B>`](Next) as the final argument.
25 /// 4. Return something that implements [`IntoResponse`].
26 ///
27 /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
28 ///
29 /// # Example
30 ///
31 /// ```rust
32 /// use axum::{
33 ///     Router,
34 ///     http::{self, Request},
35 ///     routing::get,
36 ///     response::Response,
37 ///     middleware::{self, Next},
38 /// };
39 ///
40 /// async fn my_middleware<B>(
41 ///     request: Request<B>,
42 ///     next: Next<B>,
43 /// ) -> Response {
44 ///     // do something with `request`...
45 ///
46 ///     let response = next.run(request).await;
47 ///
48 ///     // do something with `response`...
49 ///
50 ///     response
51 /// }
52 ///
53 /// let app = Router::new()
54 ///     .route("/", get(|| async { /* ... */ }))
55 ///     .layer(middleware::from_fn(my_middleware));
56 /// # let app: Router = app;
57 /// ```
58 ///
59 /// # Running extractors
60 ///
61 /// ```rust
62 /// use axum::{
63 ///     Router,
64 ///     extract::TypedHeader,
65 ///     http::StatusCode,
66 ///     headers::authorization::{Authorization, Bearer},
67 ///     http::Request,
68 ///     middleware::{self, Next},
69 ///     response::Response,
70 ///     routing::get,
71 /// };
72 ///
73 /// async fn auth<B>(
74 ///     // run the `TypedHeader` extractor
75 ///     TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
76 ///     // you can also add more extractors here but the last
77 ///     // extractor must implement `FromRequest` which
78 ///     // `Request` does
79 ///     request: Request<B>,
80 ///     next: Next<B>,
81 /// ) -> Result<Response, StatusCode> {
82 ///     if token_is_valid(auth.token()) {
83 ///         let response = next.run(request).await;
84 ///         Ok(response)
85 ///     } else {
86 ///         Err(StatusCode::UNAUTHORIZED)
87 ///     }
88 /// }
89 ///
90 /// fn token_is_valid(token: &str) -> bool {
91 ///     // ...
92 ///     # false
93 /// }
94 ///
95 /// let app = Router::new()
96 ///     .route("/", get(|| async { /* ... */ }))
97 ///     .route_layer(middleware::from_fn(auth));
98 /// # let app: Router = app;
99 /// ```
100 ///
101 /// [extractors]: crate::extract::FromRequest
102 /// [`State`]: crate::extract::State
from_fn<F, T>(f: F) -> FromFnLayer<F, (), T>103 pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
104     from_fn_with_state((), f)
105 }
106 
107 /// Create a middleware from an async function with the given state.
108 ///
109 /// See [`State`](crate::extract::State) for more details about accessing state.
110 ///
111 /// # Example
112 ///
113 /// ```rust
114 /// use axum::{
115 ///     Router,
116 ///     http::{Request, StatusCode},
117 ///     routing::get,
118 ///     response::{IntoResponse, Response},
119 ///     middleware::{self, Next},
120 ///     extract::State,
121 /// };
122 ///
123 /// #[derive(Clone)]
124 /// struct AppState { /* ... */ }
125 ///
126 /// async fn my_middleware<B>(
127 ///     State(state): State<AppState>,
128 ///     // you can add more extractors here but the last
129 ///     // extractor must implement `FromRequest` which
130 ///     // `Request` does
131 ///     request: Request<B>,
132 ///     next: Next<B>,
133 /// ) -> Response {
134 ///     // do something with `request`...
135 ///
136 ///     let response = next.run(request).await;
137 ///
138 ///     // do something with `response`...
139 ///
140 ///     response
141 /// }
142 ///
143 /// let state = AppState { /* ... */ };
144 ///
145 /// let app = Router::new()
146 ///     .route("/", get(|| async { /* ... */ }))
147 ///     .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
148 ///     .with_state(state);
149 /// # let _: axum::Router = app;
150 /// ```
from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T>151 pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
152     FromFnLayer {
153         f,
154         state,
155         _extractor: PhantomData,
156     }
157 }
158 
/// A [`tower::Layer`] that wraps services in middleware built from an async function.
///
/// Layers like this one are how middleware gets applied to a [`Router`](crate::Router).
///
/// Construct it with [`from_fn`] or [`from_fn_with_state`]; see [`from_fn`] for details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    /// The middleware function, cloned into every service this layer produces.
    f: F,
    /// State made available to the function's extractors (`()` when built with [`from_fn`]).
    state: S,
    /// Records the extractor tuple type `T` without storing a value of it. The
    /// `fn() -> T` form keeps the struct's `Send`/`Sync` independent of `T`.
    _extractor: PhantomData<fn() -> T>,
}
170 
171 impl<F, S, T> Clone for FromFnLayer<F, S, T>
172 where
173     F: Clone,
174     S: Clone,
175 {
clone(&self) -> Self176     fn clone(&self) -> Self {
177         Self {
178             f: self.f.clone(),
179             state: self.state.clone(),
180             _extractor: self._extractor,
181         }
182     }
183 }
184 
185 impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
186 where
187     F: Clone,
188     S: Clone,
189 {
190     type Service = FromFn<F, S, I, T>;
191 
layer(&self, inner: I) -> Self::Service192     fn layer(&self, inner: I) -> Self::Service {
193         FromFn {
194             f: self.f.clone(),
195             state: self.state.clone(),
196             inner,
197             _extractor: PhantomData,
198         }
199     }
200 }
201 
202 impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
203 where
204     S: fmt::Debug,
205 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result206     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207         f.debug_struct("FromFnLayer")
208             // Write out the type name, without quoting it as `&type_name::<F>()` would
209             .field("f", &format_args!("{}", type_name::<F>()))
210             .field("state", &self.state)
211             .finish()
212     }
213 }
214 
/// A middleware service built from an async function.
///
/// Produced by applying a [`FromFnLayer`] (see [`from_fn`] and
/// [`from_fn_with_state`]) to an inner service.
pub struct FromFn<F, S, I, T> {
    /// The middleware function; a fresh clone is invoked for every request.
    f: F,
    /// The wrapped service, handed to the middleware as [`Next`].
    inner: I,
    /// State passed to the middleware's extractors.
    state: S,
    /// Marker for the extractor tuple type `T`; no value of `T` is stored.
    _extractor: PhantomData<fn() -> T>,
}
224 
225 impl<F, S, I, T> Clone for FromFn<F, S, I, T>
226 where
227     F: Clone,
228     I: Clone,
229     S: Clone,
230 {
clone(&self) -> Self231     fn clone(&self) -> Self {
232         Self {
233             f: self.f.clone(),
234             inner: self.inner.clone(),
235             state: self.state.clone(),
236             _extractor: self._extractor,
237         }
238     }
239 }
240 
241 macro_rules! impl_service {
242     (
243         [$($ty:ident),*], $last:ident
244     ) => {
245         #[allow(non_snake_case, unused_mut)]
246         impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
247         where
248             F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
249             $( $ty: FromRequestParts<S> + Send, )*
250             $last: FromRequest<S, B> + Send,
251             Fut: Future<Output = Out> + Send + 'static,
252             Out: IntoResponse + 'static,
253             I: Service<Request<B>, Error = Infallible>
254                 + Clone
255                 + Send
256                 + 'static,
257             I::Response: IntoResponse,
258             I::Future: Send + 'static,
259             B: Send + 'static,
260             S: Clone + Send + Sync + 'static,
261         {
262             type Response = Response;
263             type Error = Infallible;
264             type Future = ResponseFuture;
265 
266             fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
267                 self.inner.poll_ready(cx)
268             }
269 
270             fn call(&mut self, req: Request<B>) -> Self::Future {
271                 let not_ready_inner = self.inner.clone();
272                 let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
273 
274                 let mut f = self.f.clone();
275                 let state = self.state.clone();
276 
277                 let future = Box::pin(async move {
278                     let (mut parts, body) = req.into_parts();
279 
280                     $(
281                         let $ty = match $ty::from_request_parts(&mut parts, &state).await {
282                             Ok(value) => value,
283                             Err(rejection) => return rejection.into_response(),
284                         };
285                     )*
286 
287                     let req = Request::from_parts(parts, body);
288 
289                     let $last = match $last::from_request(req, &state).await {
290                         Ok(value) => value,
291                         Err(rejection) => return rejection.into_response(),
292                     };
293 
294                     let inner = ServiceBuilder::new()
295                         .boxed_clone()
296                         .map_response(IntoResponse::into_response)
297                         .service(ready_inner);
298                     let next = Next { inner };
299 
300                     f($($ty,)* $last, next).await.into_response()
301                 });
302 
303                 ResponseFuture {
304                     inner: future
305                 }
306             }
307         }
308     };
309 }
310 
311 all_the_tuples!(impl_service);
312 
313 impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
314 where
315     S: fmt::Debug,
316     I: fmt::Debug,
317 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result318     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319         f.debug_struct("FromFnLayer")
320             .field("f", &format_args!("{}", type_name::<F>()))
321             .field("inner", &self.inner)
322             .field("state", &self.state)
323             .finish()
324     }
325 }
326 
327 /// The remainder of a middleware stack, including the handler.
328 pub struct Next<B> {
329     inner: BoxCloneService<Request<B>, Response, Infallible>,
330 }
331 
332 impl<B> Next<B> {
333     /// Execute the remaining middleware stack.
run(mut self, req: Request<B>) -> Response334     pub async fn run(mut self, req: Request<B>) -> Response {
335         match self.inner.call(req).await {
336             Ok(res) => res,
337             Err(err) => match err {},
338         }
339     }
340 }
341 
342 impl<B> fmt::Debug for Next<B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result343     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344         f.debug_struct("FromFnLayer")
345             .field("inner", &self.inner)
346             .finish()
347     }
348 }
349 
350 impl<B> Clone for Next<B> {
clone(&self) -> Self351     fn clone(&self) -> Self {
352         Self {
353             inner: self.inner.clone(),
354         }
355     }
356 }
357 
358 impl<B> Service<Request<B>> for Next<B> {
359     type Response = Response;
360     type Error = Infallible;
361     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
362 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>363     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
364         self.inner.poll_ready(cx)
365     }
366 
call(&mut self, req: Request<B>) -> Self::Future367     fn call(&mut self, req: Request<B>) -> Self::Future {
368         self.inner.call(req)
369     }
370 }
371 
372 /// Response future for [`FromFn`].
373 pub struct ResponseFuture {
374     inner: BoxFuture<'static, Response>,
375 }
376 
377 impl Future for ResponseFuture {
378     type Output = Result<Response, Infallible>;
379 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>380     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
381         self.inner.as_mut().poll(cx).map(Ok)
382     }
383 }
384 
385 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result386     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387         f.debug_struct("ResponseFuture").finish()
388     }
389 }
390 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use tower::ServiceExt;

    #[crate::test]
    async fn basic() {
        // Middleware under test: adds a marker header to the request, then
        // forwards it to the rest of the stack.
        async fn insert_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
            let marker = http::HeaderValue::from_static("ok");
            req.headers_mut().insert("x-axum-test", marker);
            next.run(req).await
        }

        // Handler echoes the marker header back as the body. This shows the
        // middleware ran before it.
        async fn handle(headers: HeaderMap) -> String {
            let marker = &headers["x-axum-test"];
            marker.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = hyper::body::to_bytes(response).await.unwrap();
        assert_eq!(&body[..], b"ok");
    }
}
424