1 use async_trait::async_trait; 2 use axum_core::extract::{FromRef, FromRequestParts}; 3 use http::request::Parts; 4 use std::{ 5 convert::Infallible, 6 ops::{Deref, DerefMut}, 7 }; 8 9 /// Extractor for state. 10 /// 11 /// See ["Accessing state in middleware"][state-from-middleware] for how to 12 /// access state in middleware. 13 /// 14 /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware 15 /// 16 /// # With `Router` 17 /// 18 /// ``` 19 /// use axum::{Router, routing::get, extract::State}; 20 /// 21 /// // the application state 22 /// // 23 /// // here you can put configuration, database connection pools, or whatever 24 /// // state you need 25 /// // 26 /// // see "When states need to implement `Clone`" for more details on why we need 27 /// // `#[derive(Clone)]` here. 28 /// #[derive(Clone)] 29 /// struct AppState {} 30 /// 31 /// let state = AppState {}; 32 /// 33 /// // create a `Router` that holds our state 34 /// let app = Router::new() 35 /// .route("/", get(handler)) 36 /// // provide the state so the router can access it 37 /// .with_state(state); 38 /// 39 /// async fn handler( 40 /// // access the state via the `State` extractor 41 /// // extracting a state of the wrong type results in a compile error 42 /// State(state): State<AppState>, 43 /// ) { 44 /// // use `state`... 45 /// } 46 /// # let _: axum::Router = app; 47 /// ``` 48 /// 49 /// Note that `State` is an extractor, so be sure to put it before any body 50 /// extractors, see ["the order of extractors"][order-of-extractors]. 51 /// 52 /// [order-of-extractors]: crate::extract#the-order-of-extractors 53 /// 54 /// ## Combining stateful routers 55 /// 56 /// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`] 57 /// When combining [`Router`]s with one of these methods, the [`Router`]s must have 58 /// the same state type. Generally, this can be inferred automatically: 59 /// 60 /// ``` 61 /// use axum::{Router, routing::get, extract::State}; 62 /// 63 /// #[derive(Clone)] 64 /// struct AppState {} 65 /// 66 /// let state = AppState {}; 67 /// 68 /// // create a `Router` that will be nested within another 69 /// let api = Router::new() 70 /// .route("/posts", get(posts_handler)); 71 /// 72 /// let app = Router::new() 73 /// .nest("/api", api) 74 /// .with_state(state); 75 /// 76 /// async fn posts_handler(State(state): State<AppState>) { 77 /// // use `state`... 78 /// } 79 /// # let _: axum::Router = app; 80 /// ``` 81 /// 82 /// However, if you are composing [`Router`]s that are defined in separate scopes, 83 /// you may need to annotate the [`State`] type explicitly: 84 /// 85 /// ``` 86 /// use axum::{Router, routing::get, extract::State}; 87 /// 88 /// #[derive(Clone)] 89 /// struct AppState {} 90 /// 91 /// fn make_app() -> Router { 92 /// let state = AppState {}; 93 /// 94 /// Router::new() 95 /// .nest("/api", make_api()) 96 /// .with_state(state) // the outer Router's state is inferred 97 /// } 98 /// 99 /// // the inner Router must specify its state type to compose with the 100 /// // outer router 101 /// fn make_api() -> Router<AppState> { 102 /// Router::new() 103 /// .route("/posts", get(posts_handler)) 104 /// } 105 /// 106 /// async fn posts_handler(State(state): State<AppState>) { 107 /// // use `state`... 108 /// } 109 /// # let _: axum::Router = make_app(); 110 /// ``` 111 /// 112 /// In short, a [`Router`]'s generic state type defaults to `()` 113 /// (no state) unless [`Router::with_state`] is called or the value 114 /// of the generic type is given explicitly. 115 /// 116 /// [`Router`]: crate::Router 117 /// [`Router::merge`]: crate::Router::merge 118 /// [`Router::nest`]: crate::Router::nest 119 /// [`Router::with_state`]: crate::Router::with_state 120 /// 121 /// # With `MethodRouter` 122 /// 123 /// ``` 124 /// use axum::{routing::get, extract::State}; 125 /// 126 /// #[derive(Clone)] 127 /// struct AppState {} 128 /// 129 /// let state = AppState {}; 130 /// 131 /// let method_router_with_state = get(handler) 132 /// // provide the state so the handler can access it 133 /// .with_state(state); 134 /// 135 /// async fn handler(State(state): State<AppState>) { 136 /// // use `state`... 137 /// } 138 /// # async { 139 /// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); 140 /// # }; 141 /// ``` 142 /// 143 /// # With `Handler` 144 /// 145 /// ``` 146 /// use axum::{routing::get, handler::Handler, extract::State}; 147 /// 148 /// #[derive(Clone)] 149 /// struct AppState {} 150 /// 151 /// let state = AppState {}; 152 /// 153 /// async fn handler(State(state): State<AppState>) { 154 /// // use `state`... 155 /// } 156 /// 157 /// // provide the state so the handler can access it 158 /// let handler_with_state = handler.with_state(state); 159 /// 160 /// # async { 161 /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) 162 /// .serve(handler_with_state.into_make_service()) 163 /// .await 164 /// .expect("server failed"); 165 /// # }; 166 /// ``` 167 /// 168 /// # Substates 169 /// 170 /// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates": 171 /// 172 /// ``` 173 /// use axum::{Router, routing::get, extract::{State, FromRef}}; 174 /// 175 /// // the application state 176 /// #[derive(Clone)] 177 /// struct AppState { 178 /// // that holds some api specific state 179 /// api_state: ApiState, 180 /// } 181 /// 182 /// // the api specific state 183 /// #[derive(Clone)] 184 /// struct ApiState {} 185 /// 186 /// // support converting an `AppState` in an `ApiState` 187 /// impl FromRef<AppState> for ApiState { 188 /// fn from_ref(app_state: &AppState) -> ApiState { 189 /// app_state.api_state.clone() 190 /// } 191 /// } 192 /// 193 /// let state = AppState { 194 /// api_state: ApiState {}, 195 /// }; 196 /// 197 /// let app = Router::new() 198 /// .route("/", get(handler)) 199 /// .route("/api/users", get(api_users)) 200 /// .with_state(state); 201 /// 202 /// async fn api_users( 203 /// // access the api specific state 204 /// State(api_state): State<ApiState>, 205 /// ) { 206 /// } 207 /// 208 /// async fn handler( 209 /// // we can still access to top level state 210 /// State(state): State<AppState>, 211 /// ) { 212 /// } 213 /// # let _: axum::Router = app; 214 /// ``` 215 /// 216 /// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`. 217 /// 218 /// # For library authors 219 /// 220 /// If you're writing a library that has an extractor that needs state, this is the recommended way 221 /// to do it: 222 /// 223 /// ```rust 224 /// use axum_core::extract::{FromRequestParts, FromRef}; 225 /// use http::request::Parts; 226 /// use async_trait::async_trait; 227 /// use std::convert::Infallible; 228 /// 229 /// // the extractor your library provides 230 /// struct MyLibraryExtractor; 231 /// 232 /// #[async_trait] 233 /// impl<S> FromRequestParts<S> for MyLibraryExtractor 234 /// where 235 /// // keep `S` generic but require that it can produce a `MyLibraryState` 236 /// // this means users will have to implement `FromRef<UserState> for MyLibraryState` 237 /// MyLibraryState: FromRef<S>, 238 /// S: Send + Sync, 239 /// { 240 /// type Rejection = Infallible; 241 /// 242 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 243 /// // get a `MyLibraryState` from a reference to the state 244 /// let state = MyLibraryState::from_ref(state); 245 /// 246 /// // ... 247 /// # todo!() 248 /// } 249 /// } 250 /// 251 /// // the state your library needs 252 /// struct MyLibraryState { 253 /// // ... 254 /// } 255 /// ``` 256 /// 257 /// # When states need to implement `Clone` 258 /// 259 /// Your top level state type must implement `Clone` to be extractable with `State`: 260 /// 261 /// ``` 262 /// use axum::extract::State; 263 /// 264 /// // no substates, so to extract to `State<AppState>` we must implement `Clone` for `AppState` 265 /// #[derive(Clone)] 266 /// struct AppState {} 267 /// 268 /// async fn handler(State(state): State<AppState>) { 269 /// // ... 270 /// } 271 /// ``` 272 /// 273 /// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`]. 274 /// 275 /// This is also true if you're extracting substates, unless you _never_ extract the top level 276 /// state itself: 277 /// 278 /// ``` 279 /// use axum::extract::{State, FromRef}; 280 /// 281 /// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to 282 /// // implement `Clone` 283 /// struct AppState { 284 /// inner: InnerState, 285 /// } 286 /// 287 /// #[derive(Clone)] 288 /// struct InnerState {} 289 /// 290 /// impl FromRef<AppState> for InnerState { 291 /// fn from_ref(app_state: &AppState) -> InnerState { 292 /// app_state.inner.clone() 293 /// } 294 /// } 295 /// 296 /// async fn api_users(State(inner): State<InnerState>) { 297 /// // ... 298 /// } 299 /// ``` 300 /// 301 /// In general however we recommend you implement `Clone` for all your state types to avoid 302 /// potential type errors. 303 /// 304 /// # Shared mutable state 305 /// 306 /// [As state is global within a `Router`][global] you can't directly get a mutable reference to 307 /// the state. 308 /// 309 /// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on 310 /// your use case. See [the tokio docs] for more details. 311 /// 312 /// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send` 313 /// futures which are incompatible with axum. If you need to hold a mutex across `.await` points, 314 /// consider using a `tokio::sync::Mutex` instead. 315 /// 316 /// ## Example 317 /// 318 /// ``` 319 /// use axum::{Router, routing::get, extract::State}; 320 /// use std::sync::{Arc, Mutex}; 321 /// 322 /// #[derive(Clone)] 323 /// struct AppState { 324 /// data: Arc<Mutex<String>>, 325 /// } 326 /// 327 /// async fn handler(State(state): State<AppState>) { 328 /// let mut data = state.data.lock().expect("mutex was poisoned"); 329 /// *data = "updated foo".to_owned(); 330 /// 331 /// // ... 332 /// } 333 /// 334 /// let state = AppState { 335 /// data: Arc::new(Mutex::new("foo".to_owned())), 336 /// }; 337 /// 338 /// let app = Router::new() 339 /// .route("/", get(handler)) 340 /// .with_state(state); 341 /// # let _: Router = app; 342 /// ``` 343 /// 344 /// [global]: crate::Router::with_state 345 /// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use 346 #[derive(Debug, Default, Clone, Copy)] 347 pub struct State<S>(pub S); 348 349 #[async_trait] 350 impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState> 351 where 352 InnerState: FromRef<OuterState>, 353 OuterState: Send + Sync, 354 { 355 type Rejection = Infallible; 356 from_request_parts( _parts: &mut Parts, state: &OuterState, ) -> Result<Self, Self::Rejection>357 async fn from_request_parts( 358 _parts: &mut Parts, 359 state: &OuterState, 360 ) -> Result<Self, Self::Rejection> { 361 let inner_state = InnerState::from_ref(state); 362 Ok(Self(inner_state)) 363 } 364 } 365 366 impl<S> Deref for State<S> { 367 type Target = S; 368 deref(&self) -> &Self::Target369 fn deref(&self) -> &Self::Target { 370 &self.0 371 } 372 } 373 374 impl<S> DerefMut for State<S> { deref_mut(&mut self) -> &mut Self::Target375 fn deref_mut(&mut self) -> &mut Self::Target { 376 &mut self.0 377 } 378 } 379