1 use async_trait::async_trait;
2 use axum_core::extract::{FromRef, FromRequestParts};
3 use http::request::Parts;
4 use std::{
5     convert::Infallible,
6     ops::{Deref, DerefMut},
7 };
8 
9 /// Extractor for state.
10 ///
11 /// See ["Accessing state in middleware"][state-from-middleware] for how to
12 /// access state in middleware.
13 ///
14 /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
15 ///
16 /// # With `Router`
17 ///
18 /// ```
19 /// use axum::{Router, routing::get, extract::State};
20 ///
21 /// // the application state
22 /// //
23 /// // here you can put configuration, database connection pools, or whatever
24 /// // state you need
25 /// //
26 /// // see "When states need to implement `Clone`" for more details on why we need
27 /// // `#[derive(Clone)]` here.
28 /// #[derive(Clone)]
29 /// struct AppState {}
30 ///
31 /// let state = AppState {};
32 ///
33 /// // create a `Router` that holds our state
34 /// let app = Router::new()
35 ///     .route("/", get(handler))
36 ///     // provide the state so the router can access it
37 ///     .with_state(state);
38 ///
39 /// async fn handler(
40 ///     // access the state via the `State` extractor
41 ///     // extracting a state of the wrong type results in a compile error
42 ///     State(state): State<AppState>,
43 /// ) {
44 ///     // use `state`...
45 /// }
46 /// # let _: axum::Router = app;
47 /// ```
48 ///
49 /// Note that `State` is an extractor, so be sure to put it before any body
50 /// extractors, see ["the order of extractors"][order-of-extractors].
51 ///
52 /// [order-of-extractors]: crate::extract#the-order-of-extractors
53 ///
54 /// ## Combining stateful routers
55 ///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
57 /// When combining [`Router`]s with one of these methods, the [`Router`]s must have
58 /// the same state type. Generally, this can be inferred automatically:
59 ///
60 /// ```
61 /// use axum::{Router, routing::get, extract::State};
62 ///
63 /// #[derive(Clone)]
64 /// struct AppState {}
65 ///
66 /// let state = AppState {};
67 ///
68 /// // create a `Router` that will be nested within another
69 /// let api = Router::new()
70 ///     .route("/posts", get(posts_handler));
71 ///
72 /// let app = Router::new()
73 ///     .nest("/api", api)
74 ///     .with_state(state);
75 ///
76 /// async fn posts_handler(State(state): State<AppState>) {
77 ///     // use `state`...
78 /// }
79 /// # let _: axum::Router = app;
80 /// ```
81 ///
82 /// However, if you are composing [`Router`]s that are defined in separate scopes,
83 /// you may need to annotate the [`State`] type explicitly:
84 ///
85 /// ```
86 /// use axum::{Router, routing::get, extract::State};
87 ///
88 /// #[derive(Clone)]
89 /// struct AppState {}
90 ///
91 /// fn make_app() -> Router {
92 ///     let state = AppState {};
93 ///
94 ///     Router::new()
95 ///         .nest("/api", make_api())
96 ///         .with_state(state) // the outer Router's state is inferred
97 /// }
98 ///
99 /// // the inner Router must specify its state type to compose with the
100 /// // outer router
101 /// fn make_api() -> Router<AppState> {
102 ///     Router::new()
103 ///         .route("/posts", get(posts_handler))
104 /// }
105 ///
106 /// async fn posts_handler(State(state): State<AppState>) {
107 ///     // use `state`...
108 /// }
109 /// # let _: axum::Router = make_app();
110 /// ```
111 ///
112 /// In short, a [`Router`]'s generic state type defaults to `()`
113 /// (no state) unless [`Router::with_state`] is called or the value
114 /// of the generic type is given explicitly.
115 ///
116 /// [`Router`]: crate::Router
117 /// [`Router::merge`]: crate::Router::merge
118 /// [`Router::nest`]: crate::Router::nest
119 /// [`Router::with_state`]: crate::Router::with_state
120 ///
121 /// # With `MethodRouter`
122 ///
123 /// ```
124 /// use axum::{routing::get, extract::State};
125 ///
126 /// #[derive(Clone)]
127 /// struct AppState {}
128 ///
129 /// let state = AppState {};
130 ///
131 /// let method_router_with_state = get(handler)
132 ///     // provide the state so the handler can access it
133 ///     .with_state(state);
134 ///
135 /// async fn handler(State(state): State<AppState>) {
136 ///     // use `state`...
137 /// }
138 /// # async {
139 /// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap();
140 /// # };
141 /// ```
142 ///
143 /// # With `Handler`
144 ///
145 /// ```
146 /// use axum::{routing::get, handler::Handler, extract::State};
147 ///
148 /// #[derive(Clone)]
149 /// struct AppState {}
150 ///
151 /// let state = AppState {};
152 ///
153 /// async fn handler(State(state): State<AppState>) {
154 ///     // use `state`...
155 /// }
156 ///
157 /// // provide the state so the handler can access it
158 /// let handler_with_state = handler.with_state(state);
159 ///
160 /// # async {
161 /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
162 ///     .serve(handler_with_state.into_make_service())
163 ///     .await
164 ///     .expect("server failed");
165 /// # };
166 /// ```
167 ///
168 /// # Substates
169 ///
170 /// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates":
171 ///
172 /// ```
173 /// use axum::{Router, routing::get, extract::{State, FromRef}};
174 ///
175 /// // the application state
176 /// #[derive(Clone)]
177 /// struct AppState {
178 ///     // that holds some api specific state
179 ///     api_state: ApiState,
180 /// }
181 ///
182 /// // the api specific state
183 /// #[derive(Clone)]
184 /// struct ApiState {}
185 ///
/// // support converting an `AppState` into an `ApiState`
187 /// impl FromRef<AppState> for ApiState {
188 ///     fn from_ref(app_state: &AppState) -> ApiState {
189 ///         app_state.api_state.clone()
190 ///     }
191 /// }
192 ///
193 /// let state = AppState {
194 ///     api_state: ApiState {},
195 /// };
196 ///
197 /// let app = Router::new()
198 ///     .route("/", get(handler))
199 ///     .route("/api/users", get(api_users))
200 ///     .with_state(state);
201 ///
202 /// async fn api_users(
203 ///     // access the api specific state
204 ///     State(api_state): State<ApiState>,
205 /// ) {
206 /// }
207 ///
208 /// async fn handler(
///     // we can still access the top level state
210 ///     State(state): State<AppState>,
211 /// ) {
212 /// }
213 /// # let _: axum::Router = app;
214 /// ```
215 ///
216 /// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
217 ///
218 /// # For library authors
219 ///
220 /// If you're writing a library that has an extractor that needs state, this is the recommended way
221 /// to do it:
222 ///
223 /// ```rust
224 /// use axum_core::extract::{FromRequestParts, FromRef};
225 /// use http::request::Parts;
226 /// use async_trait::async_trait;
227 /// use std::convert::Infallible;
228 ///
229 /// // the extractor your library provides
230 /// struct MyLibraryExtractor;
231 ///
232 /// #[async_trait]
233 /// impl<S> FromRequestParts<S> for MyLibraryExtractor
234 /// where
235 ///     // keep `S` generic but require that it can produce a `MyLibraryState`
236 ///     // this means users will have to implement `FromRef<UserState> for MyLibraryState`
237 ///     MyLibraryState: FromRef<S>,
238 ///     S: Send + Sync,
239 /// {
240 ///     type Rejection = Infallible;
241 ///
242 ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
243 ///         // get a `MyLibraryState` from a reference to the state
244 ///         let state = MyLibraryState::from_ref(state);
245 ///
246 ///         // ...
247 ///         # todo!()
248 ///     }
249 /// }
250 ///
251 /// // the state your library needs
252 /// struct MyLibraryState {
253 ///     // ...
254 /// }
255 /// ```
256 ///
257 /// # When states need to implement `Clone`
258 ///
259 /// Your top level state type must implement `Clone` to be extractable with `State`:
260 ///
261 /// ```
262 /// use axum::extract::State;
263 ///
/// // no substates, so to extract `State<AppState>` we must implement `Clone` for `AppState`
265 /// #[derive(Clone)]
266 /// struct AppState {}
267 ///
268 /// async fn handler(State(state): State<AppState>) {
269 ///     // ...
270 /// }
271 /// ```
272 ///
273 /// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`].
274 ///
275 /// This is also true if you're extracting substates, unless you _never_ extract the top level
276 /// state itself:
277 ///
278 /// ```
279 /// use axum::extract::{State, FromRef};
280 ///
281 /// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to
282 /// // implement `Clone`
283 /// struct AppState {
284 ///     inner: InnerState,
285 /// }
286 ///
287 /// #[derive(Clone)]
288 /// struct InnerState {}
289 ///
290 /// impl FromRef<AppState> for InnerState {
291 ///     fn from_ref(app_state: &AppState) -> InnerState {
292 ///         app_state.inner.clone()
293 ///     }
294 /// }
295 ///
296 /// async fn api_users(State(inner): State<InnerState>) {
297 ///     // ...
298 /// }
299 /// ```
300 ///
/// In general, however, we recommend you implement `Clone` for all your state types to avoid
302 /// potential type errors.
303 ///
304 /// # Shared mutable state
305 ///
306 /// [As state is global within a `Router`][global] you can't directly get a mutable reference to
307 /// the state.
308 ///
309 /// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
310 /// your use case. See [the tokio docs] for more details.
311 ///
312 /// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
313 /// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
314 /// consider using a `tokio::sync::Mutex` instead.
315 ///
316 /// ## Example
317 ///
318 /// ```
319 /// use axum::{Router, routing::get, extract::State};
320 /// use std::sync::{Arc, Mutex};
321 ///
322 /// #[derive(Clone)]
323 /// struct AppState {
324 ///     data: Arc<Mutex<String>>,
325 /// }
326 ///
327 /// async fn handler(State(state): State<AppState>) {
328 ///     let mut data = state.data.lock().expect("mutex was poisoned");
329 ///     *data = "updated foo".to_owned();
330 ///
331 ///     // ...
332 /// }
333 ///
334 /// let state = AppState {
335 ///     data: Arc::new(Mutex::new("foo".to_owned())),
336 /// };
337 ///
338 /// let app = Router::new()
339 ///     .route("/", get(handler))
340 ///     .with_state(state);
341 /// # let _: Router = app;
342 /// ```
343 ///
344 /// [global]: crate::Router::with_state
345 /// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
#[derive(Clone, Copy)]
#[derive(Debug, Default)]
pub struct State<S>(pub S);
348 
349 #[async_trait]
350 impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
351 where
352     InnerState: FromRef<OuterState>,
353     OuterState: Send + Sync,
354 {
355     type Rejection = Infallible;
356 
from_request_parts( _parts: &mut Parts, state: &OuterState, ) -> Result<Self, Self::Rejection>357     async fn from_request_parts(
358         _parts: &mut Parts,
359         state: &OuterState,
360     ) -> Result<Self, Self::Rejection> {
361         let inner_state = InnerState::from_ref(state);
362         Ok(Self(inner_state))
363     }
364 }
365 
366 impl<S> Deref for State<S> {
367     type Target = S;
368 
deref(&self) -> &Self::Target369     fn deref(&self) -> &Self::Target {
370         &self.0
371     }
372 }
373 
374 impl<S> DerefMut for State<S> {
deref_mut(&mut self) -> &mut Self::Target375     fn deref_mut(&mut self) -> &mut Self::Target {
376         &mut self.0
377     }
378 }
379