1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::{FromRequest, FromRequestParts};
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6 any::type_name,
7 convert::Infallible,
8 fmt,
9 future::Future,
10 marker::PhantomData,
11 pin::Pin,
12 task::{Context, Poll},
13 };
14 use tower::{util::BoxCloneService, ServiceBuilder};
15 use tower_layer::Layer;
16 use tower_service::Service;
17
18 /// Create a middleware from an async function.
19 ///
/// `from_fn` requires that the function passed to it:
21 ///
22 /// 1. Be an `async fn`.
23 /// 2. Take one or more [extractors] as the first arguments.
24 /// 3. Take [`Next<B>`](Next) as the final argument.
25 /// 4. Return something that implements [`IntoResponse`].
26 ///
27 /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`].
28 ///
29 /// # Example
30 ///
31 /// ```rust
32 /// use axum::{
33 /// Router,
34 /// http::{self, Request},
35 /// routing::get,
36 /// response::Response,
37 /// middleware::{self, Next},
38 /// };
39 ///
40 /// async fn my_middleware<B>(
41 /// request: Request<B>,
42 /// next: Next<B>,
43 /// ) -> Response {
44 /// // do something with `request`...
45 ///
46 /// let response = next.run(request).await;
47 ///
48 /// // do something with `response`...
49 ///
50 /// response
51 /// }
52 ///
53 /// let app = Router::new()
54 /// .route("/", get(|| async { /* ... */ }))
55 /// .layer(middleware::from_fn(my_middleware));
56 /// # let app: Router = app;
57 /// ```
58 ///
59 /// # Running extractors
60 ///
61 /// ```rust
62 /// use axum::{
63 /// Router,
64 /// extract::TypedHeader,
65 /// http::StatusCode,
66 /// headers::authorization::{Authorization, Bearer},
67 /// http::Request,
68 /// middleware::{self, Next},
69 /// response::Response,
70 /// routing::get,
71 /// };
72 ///
73 /// async fn auth<B>(
74 /// // run the `TypedHeader` extractor
75 /// TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
76 /// // you can also add more extractors here but the last
77 /// // extractor must implement `FromRequest` which
78 /// // `Request` does
79 /// request: Request<B>,
80 /// next: Next<B>,
81 /// ) -> Result<Response, StatusCode> {
82 /// if token_is_valid(auth.token()) {
83 /// let response = next.run(request).await;
84 /// Ok(response)
85 /// } else {
86 /// Err(StatusCode::UNAUTHORIZED)
87 /// }
88 /// }
89 ///
90 /// fn token_is_valid(token: &str) -> bool {
91 /// // ...
92 /// # false
93 /// }
94 ///
95 /// let app = Router::new()
96 /// .route("/", get(|| async { /* ... */ }))
97 /// .route_layer(middleware::from_fn(auth));
98 /// # let app: Router = app;
99 /// ```
100 ///
101 /// [extractors]: crate::extract::FromRequest
102 /// [`State`]: crate::extract::State
from_fn<F, T>(f: F) -> FromFnLayer<F, (), T>103 pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
104 from_fn_with_state((), f)
105 }
106
107 /// Create a middleware from an async function with the given state.
108 ///
109 /// See [`State`](crate::extract::State) for more details about accessing state.
110 ///
111 /// # Example
112 ///
113 /// ```rust
114 /// use axum::{
115 /// Router,
116 /// http::{Request, StatusCode},
117 /// routing::get,
118 /// response::{IntoResponse, Response},
119 /// middleware::{self, Next},
120 /// extract::State,
121 /// };
122 ///
123 /// #[derive(Clone)]
124 /// struct AppState { /* ... */ }
125 ///
126 /// async fn my_middleware<B>(
127 /// State(state): State<AppState>,
128 /// // you can add more extractors here but the last
129 /// // extractor must implement `FromRequest` which
130 /// // `Request` does
131 /// request: Request<B>,
132 /// next: Next<B>,
133 /// ) -> Response {
134 /// // do something with `request`...
135 ///
136 /// let response = next.run(request).await;
137 ///
138 /// // do something with `response`...
139 ///
140 /// response
141 /// }
142 ///
143 /// let state = AppState { /* ... */ };
144 ///
145 /// let app = Router::new()
146 /// .route("/", get(|| async { /* ... */ }))
147 /// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
148 /// .with_state(state);
149 /// # let _: axum::Router = app;
150 /// ```
from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T>151 pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
152 FromFnLayer {
153 f,
154 state,
155 _extractor: PhantomData,
156 }
157 }
158
/// A [`tower::Layer`] from an async function.
///
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)s.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
#[must_use]
pub struct FromFnLayer<F, S, T> {
    // The middleware function; cloned into every service this layer produces.
    f: F,
    // State made available to the extractors; `()` when built with `from_fn`.
    state: S,
    // Records the extractor tuple type `T` without storing a value of it. The
    // `fn() -> T` form keeps the marker `Send`/`Sync` regardless of `T`.
    _extractor: PhantomData<fn() -> T>,
}
170
171 impl<F, S, T> Clone for FromFnLayer<F, S, T>
172 where
173 F: Clone,
174 S: Clone,
175 {
clone(&self) -> Self176 fn clone(&self) -> Self {
177 Self {
178 f: self.f.clone(),
179 state: self.state.clone(),
180 _extractor: self._extractor,
181 }
182 }
183 }
184
185 impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
186 where
187 F: Clone,
188 S: Clone,
189 {
190 type Service = FromFn<F, S, I, T>;
191
layer(&self, inner: I) -> Self::Service192 fn layer(&self, inner: I) -> Self::Service {
193 FromFn {
194 f: self.f.clone(),
195 state: self.state.clone(),
196 inner,
197 _extractor: PhantomData,
198 }
199 }
200 }
201
202 impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
203 where
204 S: fmt::Debug,
205 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 f.debug_struct("FromFnLayer")
208 // Write out the type name, without quoting it as `&type_name::<F>()` would
209 .field("f", &format_args!("{}", type_name::<F>()))
210 .field("state", &self.state)
211 .finish()
212 }
213 }
214
/// A middleware created from an async function.
///
/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details.
pub struct FromFn<F, S, I, T> {
    // The middleware function; cloned for each request.
    f: F,
    // The wrapped service, i.e. the rest of the middleware stack.
    inner: I,
    // State handed to the extractors; cloned for each request.
    state: S,
    // Records the extractor tuple type `T` without storing a value of it.
    _extractor: PhantomData<fn() -> T>,
}
224
225 impl<F, S, I, T> Clone for FromFn<F, S, I, T>
226 where
227 F: Clone,
228 I: Clone,
229 S: Clone,
230 {
clone(&self) -> Self231 fn clone(&self) -> Self {
232 Self {
233 f: self.f.clone(),
234 inner: self.inner.clone(),
235 state: self.state.clone(),
236 _extractor: self._extractor,
237 }
238 }
239 }
240
// Generates `impl Service<Request<B>> for FromFn<F, S, I, (T1, ..., Tn, Last)>` for a
// single extractor arity.
//
// - `$ty`: the leading extractors. They only see the request *parts*, so they must
//   implement `FromRequestParts<S>`, and they run one after another.
// - `$last`: the final extractor. It receives the reassembled request, body included,
//   so it only needs `FromRequest<S, B>`.
//
// The middleware function `F` is then called with every extracted value plus a
// `Next<B>` that wraps the inner service.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extractor type parameters (`$ty`, `$last`) are reused as
        // local variable names below. `unused_mut`: `parts` is only mutated when there
        // is at least one leading extractor.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S, B> + Send,
            Fut: Future<Output = Out> + Send + 'static,
            Out: IntoResponse + 'static,
            I: Service<Request<B>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            B: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // `self.inner` is the instance that `poll_ready` drove to readiness. Move
                // that ready instance into the future and leave a fresh clone in its place.
                // The clone has not been polled, so the next `call` must again be preceded
                // by `poll_ready`, as the `Service` contract requires.
                let not_ready_inner = self.inner.clone();
                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                // Clone per request so the boxed future is `'static` and does not borrow
                // `self`.
                let mut f = self.f.clone();
                let state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the leading extractors against the request parts, in order.
                    // The first rejection short-circuits and becomes the response.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request so the last extractor can consume it, body
                    // included.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // Type-erase the ready inner service and convert its responses into
                    // `Response`, so `Next<B>` has one concrete service type.
                    let inner = ServiceBuilder::new()
                        .boxed_clone()
                        .map_response(IntoResponse::into_response)
                        .service(ready_inner);
                    let next = Next { inner };

                    f($($ty,)* $last, next).await.into_response()
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// NOTE(review): `all_the_tuples!` is defined elsewhere in the crate; presumably it
// invokes `impl_service!` once per supported extractor count — confirm.
all_the_tuples!(impl_service);
312
313 impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
314 where
315 S: fmt::Debug,
316 I: fmt::Debug,
317 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 f.debug_struct("FromFnLayer")
320 .field("f", &format_args!("{}", type_name::<F>()))
321 .field("inner", &self.inner)
322 .field("state", &self.state)
323 .finish()
324 }
325 }
326
327 /// The remainder of a middleware stack, including the handler.
328 pub struct Next<B> {
329 inner: BoxCloneService<Request<B>, Response, Infallible>,
330 }
331
332 impl<B> Next<B> {
333 /// Execute the remaining middleware stack.
run(mut self, req: Request<B>) -> Response334 pub async fn run(mut self, req: Request<B>) -> Response {
335 match self.inner.call(req).await {
336 Ok(res) => res,
337 Err(err) => match err {},
338 }
339 }
340 }
341
342 impl<B> fmt::Debug for Next<B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 f.debug_struct("FromFnLayer")
345 .field("inner", &self.inner)
346 .finish()
347 }
348 }
349
350 impl<B> Clone for Next<B> {
clone(&self) -> Self351 fn clone(&self) -> Self {
352 Self {
353 inner: self.inner.clone(),
354 }
355 }
356 }
357
358 impl<B> Service<Request<B>> for Next<B> {
359 type Response = Response;
360 type Error = Infallible;
361 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
362
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>363 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
364 self.inner.poll_ready(cx)
365 }
366
call(&mut self, req: Request<B>) -> Self::Future367 fn call(&mut self, req: Request<B>) -> Self::Future {
368 self.inner.call(req)
369 }
370 }
371
372 /// Response future for [`FromFn`].
373 pub struct ResponseFuture {
374 inner: BoxFuture<'static, Response>,
375 }
376
377 impl Future for ResponseFuture {
378 type Output = Result<Response, Infallible>;
379
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>380 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
381 self.inner.as_mut().poll(cx).map(Ok)
382 }
383 }
384
385 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result386 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 f.debug_struct("ResponseFuture").finish()
388 }
389 }
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use tower::ServiceExt;

    #[crate::test]
    async fn basic() {
        // Middleware that stamps a header onto every request before forwarding it.
        async fn insert_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
            req.headers_mut()
                .insert("x-axum-test", "ok".parse().unwrap());
            next.run(req).await
        }

        // Handler that echoes the header back, proving the middleware ran first.
        async fn handle(headers: HeaderMap) -> String {
            headers["x-axum-test"].to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = hyper::body::to_bytes(response).await.unwrap();
        assert_eq!(&body[..], b"ok");
    }
}
424