1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::FromRequestParts;
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6     any::type_name,
7     convert::Infallible,
8     fmt,
9     future::Future,
10     marker::PhantomData,
11     pin::Pin,
12     task::{Context, Poll},
13 };
14 use tower_layer::Layer;
15 use tower_service::Service;
16 
17 /// Create a middleware from an async function that transforms a response.
18 ///
19 /// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific
20 /// extractors.
21 ///
22 /// # Example
23 ///
24 /// ```
25 /// use axum::{
26 ///     Router,
27 ///     routing::get,
28 ///     middleware::map_response,
29 ///     response::Response,
30 /// };
31 ///
32 /// async fn set_header<B>(mut response: Response<B>) -> Response<B> {
33 ///     response.headers_mut().insert("x-foo", "foo".parse().unwrap());
34 ///     response
35 /// }
36 ///
37 /// let app = Router::new()
38 ///     .route("/", get(|| async { /* ... */ }))
39 ///     .layer(map_response(set_header));
40 /// # let _: Router = app;
41 /// ```
42 ///
43 /// # Running extractors
44 ///
45 /// It is also possible to run extractors that implement [`FromRequestParts`]. These will be run
46 /// before calling the handler.
47 ///
48 /// ```
49 /// use axum::{
50 ///     Router,
51 ///     routing::get,
52 ///     middleware::map_response,
53 ///     extract::Path,
54 ///     response::Response,
55 /// };
56 /// use std::collections::HashMap;
57 ///
58 /// async fn log_path_params<B>(
59 ///     Path(path_params): Path<HashMap<String, String>>,
60 ///     response: Response<B>,
61 /// ) -> Response<B> {
62 ///     tracing::debug!(?path_params);
63 ///     response
64 /// }
65 ///
66 /// let app = Router::new()
67 ///     .route("/", get(|| async { /* ... */ }))
68 ///     .layer(map_response(log_path_params));
69 /// # let _: Router = app;
70 /// ```
71 ///
72 /// Note that to access state you must use either [`map_response_with_state`].
73 ///
74 /// # Returning any `impl IntoResponse`
75 ///
76 /// It is also possible to return anything that implements [`IntoResponse`]
77 ///
78 /// ```
79 /// use axum::{
80 ///     Router,
81 ///     routing::get,
82 ///     middleware::map_response,
83 ///     response::{Response, IntoResponse},
84 /// };
85 /// use std::collections::HashMap;
86 ///
87 /// async fn set_header(response: Response) -> impl IntoResponse {
88 ///     (
89 ///         [("x-foo", "foo")],
90 ///         response,
91 ///     )
92 /// }
93 ///
94 /// let app = Router::new()
95 ///     .route("/", get(|| async { /* ... */ }))
96 ///     .layer(map_response(set_header));
97 /// # let _: Router = app;
98 /// ```
map_response<F, T>(f: F) -> MapResponseLayer<F, (), T>99 pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100     map_response_with_state((), f)
101 }
102 
103 /// Create a middleware from an async function that transforms a response, with the given state.
104 ///
105 /// See [`State`](crate::extract::State) for more details about accessing state.
106 ///
107 /// # Example
108 ///
109 /// ```rust
110 /// use axum::{
111 ///     Router,
112 ///     http::StatusCode,
113 ///     routing::get,
114 ///     response::Response,
115 ///     middleware::map_response_with_state,
116 ///     extract::State,
117 /// };
118 ///
119 /// #[derive(Clone)]
120 /// struct AppState { /* ... */ }
121 ///
122 /// async fn my_middleware<B>(
123 ///     State(state): State<AppState>,
124 ///     // you can add more extractors here but they must
125 ///     // all implement `FromRequestParts`
126 ///     // `FromRequest` is not allowed
127 ///     response: Response<B>,
128 /// ) -> Response<B> {
129 ///     // do something with `state` and `response`...
130 ///     response
131 /// }
132 ///
133 /// let state = AppState { /* ... */ };
134 ///
135 /// let app = Router::new()
136 ///     .route("/", get(|| async { /* ... */ }))
137 ///     .route_layer(map_response_with_state(state.clone(), my_middleware))
138 ///     .with_state(state);
139 /// # let _: axum::Router = app;
140 /// ```
map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T>141 pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142     MapResponseLayer {
143         f,
144         state,
145         _extractor: PhantomData,
146     }
147 }
148 
/// A [`tower::Layer`] from an async function that transforms a response.
///
/// Created with [`map_response`] or [`map_response_with_state`]. See [`map_response`] for more
/// details.
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    /// The async function applied to each response; cloned into every service this layer builds.
    f: F,
    /// State cloned into every service this layer builds and handed to its extractors.
    state: S,
    /// Records the extractor tuple type `T`. `fn() -> T` is used so no `T` is owned.
    _extractor: PhantomData<fn() -> T>,
}
158 
159 impl<F, S, T> Clone for MapResponseLayer<F, S, T>
160 where
161     F: Clone,
162     S: Clone,
163 {
clone(&self) -> Self164     fn clone(&self) -> Self {
165         Self {
166             f: self.f.clone(),
167             state: self.state.clone(),
168             _extractor: self._extractor,
169         }
170     }
171 }
172 
173 impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174 where
175     F: Clone,
176     S: Clone,
177 {
178     type Service = MapResponse<F, S, I, T>;
179 
layer(&self, inner: I) -> Self::Service180     fn layer(&self, inner: I) -> Self::Service {
181         MapResponse {
182             f: self.f.clone(),
183             state: self.state.clone(),
184             inner,
185             _extractor: PhantomData,
186         }
187     }
188 }
189 
190 impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191 where
192     S: fmt::Debug,
193 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result194     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195         f.debug_struct("MapResponseLayer")
196             // Write out the type name, without quoting it as `&type_name::<F>()` would
197             .field("f", &format_args!("{}", type_name::<F>()))
198             .field("state", &self.state)
199             .finish()
200     }
201 }
202 
/// A middleware created from an async function that transforms a response.
///
/// Created with [`map_response`] or [`map_response_with_state`]. See [`map_response`] for more
/// details.
pub struct MapResponse<F, S, I, T> {
    /// The async function applied to the inner service's response.
    f: F,
    /// The wrapped service.
    inner: I,
    /// State handed to the extractors run before the inner service is called.
    state: S,
    /// Records the extractor tuple type `T`. `fn() -> T` is used so no `T` is owned.
    _extractor: PhantomData<fn() -> T>,
}
212 
213 impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
214 where
215     F: Clone,
216     I: Clone,
217     S: Clone,
218 {
clone(&self) -> Self219     fn clone(&self) -> Self {
220         Self {
221             f: self.f.clone(),
222             inner: self.inner.clone(),
223             state: self.state.clone(),
224             _extractor: self._extractor,
225         }
226     }
227 }
228 
// Implements `Service` for a `MapResponse` whose extractor parameter `T` is the tuple of the
// given extractor types `$ty`. Each `$ty` is used both as a generic type parameter and as the
// name of the local binding that holds that extractor's value.
macro_rules! impl_service {
    (
        $($ty:ident),*
    ) => {
        // `non_snake_case`: the extractor bindings are named after the type parameters
        // (`T1`, `T2`, ...). `unused_mut`: in the zero-extractor expansion, `mut parts` is never
        // mutably borrowed.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
        where
            // The user function receives every extracted value followed by the inner response.
            F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            // The inner service's error type is `Infallible`, so it always yields a response.
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the inner service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }


            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Move out the inner service that `poll_ready` was called on and leave a clone
                // in its place. The clone has not itself been polled ready.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                // Underscore-prefixed because the zero-extractor expansion never reads it.
                let _state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run each extractor in order. The first rejection short-circuits and becomes
                    // the response, so neither the inner service nor `f` is called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request from `parts`. Extractors received `parts` mutably
                    // and may have modified it.
                    let req = Request::from_parts(parts, body);

                    match ready_inner.call(req).await {
                        Ok(res) => {
                            f($($ty,)* res).await.into_response()
                        }
                        // `err` is `Infallible`; the empty match shows this arm is unreachable.
                        Err(err) => match err {}
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}
292 
// Generate `Service` impls for functions taking zero through sixteen extractors (followed by the
// response).
impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
310 
311 impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
312 where
313     S: fmt::Debug,
314     I: fmt::Debug,
315 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result316     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317         f.debug_struct("MapResponse")
318             .field("f", &format_args!("{}", type_name::<F>()))
319             .field("inner", &self.inner)
320             .field("state", &self.state)
321             .finish()
322     }
323 }
324 
325 /// Response future for [`MapResponse`].
326 pub struct ResponseFuture {
327     inner: BoxFuture<'static, Response>,
328 }
329 
330 impl Future for ResponseFuture {
331     type Output = Result<Response, Infallible>;
332 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>333     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334         self.inner.as_mut().poll(cx).map(Ok)
335     }
336 }
337 
338 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result339     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340         f.debug_struct("ResponseFuture").finish()
341     }
342 }
343 
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    // The mapping function runs on the response and its header change is visible to the client.
    #[crate::test]
    async fn works() {
        async fn insert_x_foo<B>(mut response: Response<B>) -> Response<B> {
            response
                .headers_mut()
                .insert("x-foo", "foo".parse().unwrap());
            response
        }

        let client = TestClient::new(Router::new().layer(map_response(insert_x_foo)));
        let response = client.get("/").send().await;

        assert_eq!(response.headers()["x-foo"], "foo");
    }
}
365