xref: /aosp_15_r20/tools/netsim/rust/http-proxy/src/connector.rs (revision cf78ab8cffb8fc9207af348f23af247fb04370a6)
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use crate::error::Error;
16 use base64::{engine::general_purpose, Engine as _};
17 use std::net::SocketAddr;
18 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19 use tokio::net::TcpStream;
20 
21 const HTTP_VERSION: &str = "1.1";
22 
23 pub type Result<T> = core::result::Result<T, Error>;
24 
/// Establishes a TCP connection to a target address through an HTTP proxy.
///
/// The `Connector` handles the CONNECT request handshake with the proxy, including
/// optional Basic authentication.
#[derive(Clone)]
pub struct Connector {
    /// Address of the HTTP proxy to send CONNECT requests to.
    proxy_addr: SocketAddr,
    /// Basic-auth user name; only used when `password` is also set.
    username: Option<String>,
    /// Basic-auth password; only used when `username` is also set.
    password: Option<String>,
}

// Hand-written instead of derived so that debug output (logs, panics) never
// contains the proxy password in clear text.
impl std::fmt::Debug for Connector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Connector")
            .field("proxy_addr", &self.proxy_addr)
            .field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
35 
36 impl Connector {
new(proxy_addr: SocketAddr, username: Option<String>, password: Option<String>) -> Self37     pub fn new(proxy_addr: SocketAddr, username: Option<String>, password: Option<String>) -> Self {
38         Connector { proxy_addr, username, password }
39     }
40 
connect(&self, addr: SocketAddr) -> Result<TcpStream>41     pub async fn connect(&self, addr: SocketAddr) -> Result<TcpStream> {
42         let mut stream = TcpStream::connect(self.proxy_addr).await?;
43 
44         // Construct the CONNECT request
45         let mut request = format!("CONNECT {} HTTP/{}\r\n", addr.to_string(), HTTP_VERSION);
46 
47         // Authentication
48         if let (Some(username), Some(password)) = (&self.username, &self.password) {
49             let encoded_auth = base64_encode(format!("{}:{}", username, password).as_bytes());
50             let auth_header = format!(
51                 "Proxy-Authorization: Basic {}\r\n",
52                 String::from_utf8_lossy(&encoded_auth)
53             );
54             // Add the header to the request
55             request.push_str(&auth_header);
56         }
57 
58         // Add the final CRLF
59         request.push_str("\r\n");
60         stream.write_all(request.as_bytes()).await?;
61 
62         // Read the proxy's response
63         let mut reader = BufReader::new(stream);
64         let mut response = String::new();
65         reader.read_line(&mut response).await?;
66         if response.starts_with(&format!("HTTP/{} 200", HTTP_VERSION)) {
67             Ok(reader.into_inner())
68         } else {
69             Err(Error::ConnectionError(addr, response.trim_end_matches("\r\n").to_string()))
70         }
71     }
72 }
73 
base64_encode(src: &[u8]) -> Vec<u8>74 fn base64_encode(src: &[u8]) -> Vec<u8> {
75     general_purpose::STANDARD.encode(src).into_bytes()
76 }
77 
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::{lookup_host, TcpListener};

    /// Resolves the tunnel target used by the tests. The stub proxy never
    /// actually connects to it; it only checks the CONNECT request.
    async fn target_addr() -> SocketAddr {
        lookup_host("localhost:8000").await.unwrap().next().unwrap()
    }

    /// Reads an HTTP request head (request line, headers, terminating blank
    /// line) from `reader`. Unlike a single `read`, this tolerates the request
    /// arriving split across several TCP segments.
    async fn read_request_head(reader: &mut BufReader<TcpStream>) -> String {
        let mut head = String::new();
        loop {
            let n = reader.read_line(&mut head).await.unwrap();
            if n == 0 || head.ends_with("\r\n\r\n") {
                return head;
            }
        }
    }

    #[tokio::test]
    async fn test_connect() -> Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let addr = target_addr().await;

        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);

            // Without credentials the proxy expects a bare CONNECT request.
            let expected_request = format!("CONNECT {} HTTP/1.1\r\n\r\n", &addr);
            assert_eq!(read_request_head(&mut reader).await, expected_request);

            // Proxy accepts the tunnel.
            let response = "HTTP/1.1 200 Connection established\r\n\r\n";
            let mut stream = reader.into_inner();
            stream.write_all(response.as_bytes()).await.unwrap();
        });

        let client = Connector::new(proxy_addr, None, None);

        client.connect(addr).await.unwrap();

        handle.await.unwrap(); // Wait for the task to complete

        Ok(())
    }

    #[tokio::test]
    async fn test_connect_with_auth() -> Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let addr = target_addr().await;

        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);

            // With credentials the proxy expects a Basic Proxy-Authorization
            // header carrying base64("user:password").
            let expected_request = format!(
                "CONNECT {} HTTP/1.1\r\nProxy-Authorization: Basic dXNlcjpwYXNzd29yZA==\r\n\r\n",
                &addr
            );
            assert_eq!(read_request_head(&mut reader).await, expected_request);

            // Proxy accepts the tunnel.
            let response = "HTTP/1.1 200 Connection established\r\n\r\n";
            let mut stream = reader.into_inner();
            stream.write_all(response.as_bytes()).await.unwrap();
        });

        let client = Connector::new(proxy_addr, Some("user".into()), Some("password".into()));

        client.connect(addr).await.unwrap();

        handle.await.unwrap(); // Wait for the task to complete

        Ok(())
    }

    #[tokio::test]
    async fn test_connect_rejected() -> Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();
        let addr = target_addr().await;

        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            read_request_head(&mut reader).await;

            // Proxy refuses the tunnel.
            let response = "HTTP/1.1 407 Proxy Authentication Required\r\n\r\n";
            let mut stream = reader.into_inner();
            stream.write_all(response.as_bytes()).await.unwrap();
        });

        let client = Connector::new(proxy_addr, None, None);

        match client.connect(addr).await {
            Err(Error::ConnectionError(target, status)) => {
                assert_eq!(target, addr);
                assert_eq!(status, "HTTP/1.1 407 Proxy Authentication Required");
            }
            other => panic!("expected ConnectionError, got ok={}", other.is_ok()),
        }

        handle.await.unwrap(); // Wait for the task to complete

        Ok(())
    }

    #[test]
    fn test_proxy_base64_encode_success() {
        let input = b"hello world";
        let encoded = base64_encode(input);
        assert_eq!(encoded, b"aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_proxy_base64_encode_empty_input() {
        let input = b"";
        let encoded = base64_encode(input);
        assert_eq!(encoded, b"");
    }
}
169