use super::{ rejection::{FailedToResolveHost, HostRejection}, FromRequestParts, }; use async_trait::async_trait; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, }; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; /// Extractor that resolves the hostname of the request. /// /// Hostname is resolved through the following, in order: /// - `Forwarded` header /// - `X-Forwarded-Host` header /// - `Host` header /// - request target / URI /// /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. #[derive(Debug, Clone)] pub struct Host(pub String); #[async_trait] impl FromRequestParts for Host where S: Send + Sync, { type Rejection = HostRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(host) = parse_forwarded(&parts.headers) { return Ok(Host(host.to_owned())); } if let Some(host) = parts .headers .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } if let Some(host) = parts .headers .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { return Ok(Host(host.to_owned())); } if let Some(host) = parts.uri.host() { return Ok(Host(host.to_owned())); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) } } #[allow(warnings)] fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { // if there are multiple `Forwarded` `HeaderMap::get` will return the first one let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?; // get the first set of values let first_value = forwarded_values.split(',').nth(0)?; // find the value of the `host` field first_value.split(';').find_map(|pair| { let (key, value) = pair.split_once('=')?; key.trim() .eq_ignore_ascii_case("host") .then(|| value.trim().trim_matches('"')) }) } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router}; use http::header::HeaderName; fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { host } TestClient::new(Router::new().route("/", get(host_as_body))) } #[crate::test] async fn host_header() { let original_host = "some-domain:123"; let host = test_client() .get("/") .header(http::header::HOST, original_host) .send() .await .text() .await; assert_eq!(host, original_host); } #[crate::test] async fn x_forwarded_host_header() { let original_host = "some-domain:456"; let host = test_client() .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, original_host) .send() .await .text() .await; assert_eq!(host, original_host); } #[crate::test] async fn x_forwarded_host_precedence_over_host_header() { let x_forwarded_host_header = "some-domain:456"; let host_header = "some-domain:123"; let host = test_client() .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header) .header(http::header::HOST, host_header) .send() .await .text() .await; assert_eq!(host, x_forwarded_host_header); } #[crate::test] async fn uri_host() { let host = test_client().get("/").send().await.text().await; assert!(host.contains("127.0.0.1")); } #[test] fn forwarded_parsing() { // the basic case let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // is case insensitive let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // ipv6 let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "[2001:db8:cafe::17]:4711"); // multiple values in one header let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); // multiple header values let headers = header_map(&[ (FORWARDED, "host=192.0.2.60"), (FORWARDED, "host=127.0.0.1"), ]); let value = parse_forwarded(&headers).unwrap(); assert_eq!(value, "192.0.2.60"); } fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap { let mut headers = HeaderMap::new(); for (key, value) in values { headers.append(key, value.parse().unwrap()); } headers } }