1 use crate::body::HttpBody;
2 use crate::extract::{rejection::*, FromRequest, RawForm};
3 use crate::BoxError;
4 use async_trait::async_trait;
5 use axum_core::response::{IntoResponse, Response};
6 use axum_core::RequestExt;
7 use http::header::CONTENT_TYPE;
8 use http::{Request, StatusCode};
9 use serde::de::DeserializeOwned;
10 use serde::Serialize;
11 
/// URL encoded extractor and response.
///
/// # As extractor
///
/// If used as an extractor, `Form` will deserialize the query parameters for `GET` and `HEAD`
/// requests and `application/x-www-form-urlencoded` encoded request bodies for other methods. It
/// supports any type that implements [`serde::Deserialize`].
///
/// ⚠️ Since parsing form data might require consuming the request body, the `Form` extractor must be
/// *last* if there are multiple extractors in a handler. See ["the order of
/// extractors"][order-of-extractors].
///
/// If the data cannot be deserialized into `T`, the request is rejected with a
/// [`FormRejection`]. A query string that fails to deserialize (`GET`/`HEAD`) yields
/// `400 Bad Request`, while a request body that fails to deserialize yields
/// `422 Unprocessable Entity`.
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// ```rust
/// use axum::Form;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct SignUp {
///     username: String,
///     password: String,
/// }
///
/// async fn accept_form(Form(sign_up): Form<SignUp>) {
///     // ...
/// }
/// ```
///
/// Note that `Content-Type: multipart/form-data` requests are not supported. Use [`Multipart`]
/// instead.
///
/// # As response
///
/// The wrapped value is serialized with `serde_urlencoded` and sent with a
/// `Content-Type: application/x-www-form-urlencoded` header. If serialization fails, a
/// `500 Internal Server Error` response whose body is the error message is returned instead.
///
/// ```rust
/// use axum::Form;
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct Payload {
///     value: String,
/// }
///
/// async fn handler() -> Form<Payload> {
///     Form(Payload { value: "foo".to_owned() })
/// }
/// ```
///
/// [`Multipart`]: crate::extract::Multipart
/// [`FormRejection`]: crate::extract::rejection::FormRejection
#[cfg_attr(docsrs, doc(cfg(feature = "form")))]
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Form<T>(pub T);
65 
66 #[async_trait]
67 impl<T, S, B> FromRequest<S, B> for Form<T>
68 where
69     T: DeserializeOwned,
70     B: HttpBody + Send + 'static,
71     B::Data: Send,
72     B::Error: Into<BoxError>,
73     S: Send + Sync,
74 {
75     type Rejection = FormRejection;
76 
from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection>77     async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
78         let is_get_or_head =
79             req.method() == http::Method::GET || req.method() == http::Method::HEAD;
80 
81         match req.extract().await {
82             Ok(RawForm(bytes)) => {
83                 let value =
84                     serde_urlencoded::from_bytes(&bytes).map_err(|err| -> FormRejection {
85                         if is_get_or_head {
86                             FailedToDeserializeForm::from_err(err).into()
87                         } else {
88                             FailedToDeserializeFormBody::from_err(err).into()
89                         }
90                     })?;
91                 Ok(Form(value))
92             }
93             Err(RawFormRejection::BytesRejection(r)) => Err(FormRejection::BytesRejection(r)),
94             Err(RawFormRejection::InvalidFormContentType(r)) => {
95                 Err(FormRejection::InvalidFormContentType(r))
96             }
97         }
98     }
99 }
100 
101 impl<T> IntoResponse for Form<T>
102 where
103     T: Serialize,
104 {
into_response(self) -> Response105     fn into_response(self) -> Response {
106         match serde_urlencoded::to_string(&self.0) {
107             Ok(body) => (
108                 [(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())],
109                 body,
110             )
111                 .into_response(),
112             Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
113         }
114     }
115 }
116 
// Generates the deref impl(s) for `Form` via axum-core's internal helper macro.
// NOTE(review): inferred from the macro name only; presumably `Deref`/`DerefMut` to the wrapped
// `T` — confirm against the macro definition in `axum_core`.
axum_core::__impl_deref!(Form);
118 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        body::{Empty, Full},
        routing::{on, MethodFilter},
        test_helpers::TestClient,
        Router,
    };
    use bytes::Bytes;
    use http::{header::CONTENT_TYPE, Method, Request};
    use mime::APPLICATION_WWW_FORM_URLENCODED;
    use serde::{Deserialize, Serialize};
    use std::fmt::Debug;

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Pagination {
        size: Option<u64>,
        page: Option<u64>,
    }

    /// Extracts a `Form<T>` from a bodiless GET request to `uri` and asserts it equals `value`.
    async fn check_query<T: DeserializeOwned + PartialEq + Debug>(uri: impl AsRef<str>, value: T) {
        let req = Request::builder()
            .uri(uri.as_ref())
            .body(Empty::<Bytes>::new())
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    /// Round-trips `value` through a url-encoded POST body and asserts the extracted value matches.
    async fn check_body<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: T) {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body(Full::<Bytes>::new(
                serde_urlencoded::to_string(&value).unwrap().into(),
            ))
            .unwrap();
        assert_eq!(Form::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_form_query() {
        check_query(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check_query(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn test_form_body() {
        check_body(Pagination {
            size: None,
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: None,
        })
        .await;

        check_body(Pagination {
            size: Some(10),
            page: Some(20),
        })
        .await;
    }

    #[crate::test]
    async fn test_incorrect_content_type() {
        let req = Request::builder()
            .uri("http://example.com/test")
            .method(Method::POST)
            .header(CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Full::<Bytes>::new(
                serde_urlencoded::to_string(&Pagination {
                    size: Some(10),
                    page: None,
                })
                .unwrap()
                .into(),
            ))
            .unwrap();
        assert!(matches!(
            Form::<Pagination>::from_request(req, &())
                .await
                .unwrap_err(),
            FormRejection::InvalidFormContentType(InvalidFormContentType)
        ));
    }

    // Uses the crate's test wrapper like every other async test in this module, rather than
    // `#[tokio::test]` directly.
    #[crate::test]
    async fn deserialize_error_status_codes() {
        #[allow(dead_code)]
        #[derive(Deserialize)]
        struct Payload {
            a: i32,
        }

        let app = Router::new().route(
            "/",
            on(
                MethodFilter::GET | MethodFilter::POST,
                |_: Form<Payload>| async {},
            ),
        );

        let client = TestClient::new(app);

        // Query-string deserialization failures are client errors (400)...
        let res = client.get("/?a=false").send().await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);

        // ...while well-formed bodies with the wrong shape are unprocessable (422).
        let res = client
            .post("/")
            .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref())
            .body("a=false")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    // Covers the `IntoResponse` half of `Form`, which previously had no test.
    #[test]
    fn form_response_sets_content_type() {
        let res = Form(Pagination {
            size: Some(10),
            page: None,
        })
        .into_response();

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers()[CONTENT_TYPE],
            APPLICATION_WWW_FORM_URLENCODED.as_ref()
        );
    }
}
264