1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::FromRequestParts;
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6 any::type_name,
7 convert::Infallible,
8 fmt,
9 future::Future,
10 marker::PhantomData,
11 pin::Pin,
12 task::{Context, Poll},
13 };
14 use tower_layer::Layer;
15 use tower_service::Service;
16
17 /// Create a middleware from an async function that transforms a response.
18 ///
19 /// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific
20 /// extractors.
21 ///
22 /// # Example
23 ///
24 /// ```
25 /// use axum::{
26 /// Router,
27 /// routing::get,
28 /// middleware::map_response,
29 /// response::Response,
30 /// };
31 ///
32 /// async fn set_header<B>(mut response: Response<B>) -> Response<B> {
33 /// response.headers_mut().insert("x-foo", "foo".parse().unwrap());
34 /// response
35 /// }
36 ///
37 /// let app = Router::new()
38 /// .route("/", get(|| async { /* ... */ }))
39 /// .layer(map_response(set_header));
40 /// # let _: Router = app;
41 /// ```
42 ///
43 /// # Running extractors
44 ///
45 /// It is also possible to run extractors that implement [`FromRequestParts`]. These will be run
46 /// before calling the handler.
47 ///
48 /// ```
49 /// use axum::{
50 /// Router,
51 /// routing::get,
52 /// middleware::map_response,
53 /// extract::Path,
54 /// response::Response,
55 /// };
56 /// use std::collections::HashMap;
57 ///
58 /// async fn log_path_params<B>(
59 /// Path(path_params): Path<HashMap<String, String>>,
60 /// response: Response<B>,
61 /// ) -> Response<B> {
62 /// tracing::debug!(?path_params);
63 /// response
64 /// }
65 ///
66 /// let app = Router::new()
67 /// .route("/", get(|| async { /* ... */ }))
68 /// .layer(map_response(log_path_params));
69 /// # let _: Router = app;
70 /// ```
71 ///
/// Note that to access state you must use [`map_response_with_state`].
73 ///
74 /// # Returning any `impl IntoResponse`
75 ///
/// It is also possible to return anything that implements [`IntoResponse`].
77 ///
78 /// ```
79 /// use axum::{
80 /// Router,
81 /// routing::get,
82 /// middleware::map_response,
83 /// response::{Response, IntoResponse},
84 /// };
85 /// use std::collections::HashMap;
86 ///
87 /// async fn set_header(response: Response) -> impl IntoResponse {
88 /// (
89 /// [("x-foo", "foo")],
90 /// response,
91 /// )
92 /// }
93 ///
94 /// let app = Router::new()
95 /// .route("/", get(|| async { /* ... */ }))
96 /// .layer(map_response(set_header));
97 /// # let _: Router = app;
98 /// ```
map_response<F, T>(f: F) -> MapResponseLayer<F, (), T>99 pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100 map_response_with_state((), f)
101 }
102
103 /// Create a middleware from an async function that transforms a response, with the given state.
104 ///
105 /// See [`State`](crate::extract::State) for more details about accessing state.
106 ///
107 /// # Example
108 ///
109 /// ```rust
110 /// use axum::{
111 /// Router,
112 /// http::StatusCode,
113 /// routing::get,
114 /// response::Response,
115 /// middleware::map_response_with_state,
116 /// extract::State,
117 /// };
118 ///
119 /// #[derive(Clone)]
120 /// struct AppState { /* ... */ }
121 ///
122 /// async fn my_middleware<B>(
123 /// State(state): State<AppState>,
124 /// // you can add more extractors here but they must
125 /// // all implement `FromRequestParts`
126 /// // `FromRequest` is not allowed
127 /// response: Response<B>,
128 /// ) -> Response<B> {
129 /// // do something with `state` and `response`...
130 /// response
131 /// }
132 ///
133 /// let state = AppState { /* ... */ };
134 ///
135 /// let app = Router::new()
136 /// .route("/", get(|| async { /* ... */ }))
137 /// .route_layer(map_response_with_state(state.clone(), my_middleware))
138 /// .with_state(state);
139 /// # let _: axum::Router = app;
140 /// ```
map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T>141 pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142 MapResponseLayer {
143 f,
144 state,
145 _extractor: PhantomData,
146 }
147 }
148
/// A [`tower::Layer`] from an async function that transforms a response.
///
/// Created with [`map_response`] or [`map_response_with_state`]. See those functions for more
/// details.
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    /// The user's async mapping function; cloned into every service produced by `layer`.
    f: F,
    /// State handed to the extractors; `()` when the layer was built with `map_response`.
    state: S,
    /// Records the extractor tuple type `T` without storing a value. Using `fn() -> T` keeps
    /// the marker from making the layer's `Send`/`Sync` depend on `T`.
    _extractor: PhantomData<fn() -> T>,
}
158
159 impl<F, S, T> Clone for MapResponseLayer<F, S, T>
160 where
161 F: Clone,
162 S: Clone,
163 {
clone(&self) -> Self164 fn clone(&self) -> Self {
165 Self {
166 f: self.f.clone(),
167 state: self.state.clone(),
168 _extractor: self._extractor,
169 }
170 }
171 }
172
173 impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174 where
175 F: Clone,
176 S: Clone,
177 {
178 type Service = MapResponse<F, S, I, T>;
179
layer(&self, inner: I) -> Self::Service180 fn layer(&self, inner: I) -> Self::Service {
181 MapResponse {
182 f: self.f.clone(),
183 state: self.state.clone(),
184 inner,
185 _extractor: PhantomData,
186 }
187 }
188 }
189
190 impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191 where
192 S: fmt::Debug,
193 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("MapResponseLayer")
196 // Write out the type name, without quoting it as `&type_name::<F>()` would
197 .field("f", &format_args!("{}", type_name::<F>()))
198 .field("state", &self.state)
199 .finish()
200 }
201 }
202
/// A middleware created from an async function that transforms a response.
///
/// Created with [`map_response`] or [`map_response_with_state`] (via [`MapResponseLayer`]).
/// See [`map_response`] for more details.
pub struct MapResponse<F, S, I, T> {
    /// The user's mapping function; a fresh clone is moved into each response future.
    f: F,
    /// The wrapped service whose responses are passed to `f`.
    inner: I,
    /// State passed to each extractor's `from_request_parts`.
    state: S,
    /// Marker for the extractor tuple type `T`; no value of `T` is stored.
    _extractor: PhantomData<fn() -> T>,
}
212
213 impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
214 where
215 F: Clone,
216 I: Clone,
217 S: Clone,
218 {
clone(&self) -> Self219 fn clone(&self) -> Self {
220 Self {
221 f: self.f.clone(),
222 inner: self.inner.clone(),
223 state: self.state.clone(),
224 _extractor: self._extractor,
225 }
226 }
227 }
228
// Implements `Service` for `MapResponse` whose extractor tuple is `($ty,)*`. The mapping
// function receives one value per extractor, followed by the inner service's response.
macro_rules! impl_service {
    (
        $($ty:ident),*
    ) => {
        // `non_snake_case`: extracted values are bound to locals named after their type
        // parameters (`T1`, `T2`, ...).
        // `unused_mut`: `parts` is only mutated when at least one extractor runs.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
        where
            F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the inner service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }


            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Move the instance that `poll_ready` was just called on into the future,
                // leaving a fresh clone in `self` for later requests.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                // Underscore-prefixed because it is unused when there are no extractors.
                let _state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run each extractor in order against the request parts. The first
                    // rejection becomes the response and the inner service is never called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    let req = Request::from_parts(parts, body);

                    match ready_inner.call(req).await {
                        Ok(res) => {
                            f($($ty,)* res).await.into_response()
                        }
                        // The inner error type is `Infallible`, so this arm cannot run; the
                        // empty match lets the compiler prove that.
                        Err(err) => match err {}
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// Generate `Service` impls for zero up to sixteen extractors.
impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
310
311 impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
312 where
313 S: fmt::Debug,
314 I: fmt::Debug,
315 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result316 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317 f.debug_struct("MapResponse")
318 .field("f", &format_args!("{}", type_name::<F>()))
319 .field("inner", &self.inner)
320 .field("state", &self.state)
321 .finish()
322 }
323 }
324
/// Response future for [`MapResponse`].
///
/// Resolves to `Ok` with the final response. Extractor rejections are already turned into
/// responses inside the wrapped future, and the inner service cannot fail.
pub struct ResponseFuture {
    /// The boxed pipeline that runs the extractors, the inner service and the mapping function.
    inner: BoxFuture<'static, Response>,
}
329
330 impl Future for ResponseFuture {
331 type Output = Result<Response, Infallible>;
332
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>333 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334 self.inner.as_mut().poll(cx).map(Ok)
335 }
336 }
337
338 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result339 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
340 f.debug_struct("ResponseFuture").finish()
341 }
342 }
343
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    /// A response-mapping layer on a router with no routes still runs and can add a header.
    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut res: Response<B>) -> Response<B> {
            res.headers_mut().insert("x-foo", "foo".parse().unwrap());
            res
        }

        let layer = map_response(add_header);
        let app = Router::new().layer(layer);
        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        let headers = res.headers();

        assert_eq!(headers["x-foo"], "foo");
    }
}
365