1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use crate::error::Error;
16 use base64::{engine::general_purpose, Engine as _};
17 use std::net::SocketAddr;
18 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19 use tokio::net::TcpStream;
20
21 const HTTP_VERSION: &str = "1.1";
22
23 pub type Result<T> = core::result::Result<T, Error>;
24
/// Opens TCP connections to target addresses by tunnelling through an HTTP
/// proxy with the `CONNECT` method.
///
/// Optional credentials are presented to the proxy using HTTP Basic
/// authentication.
#[derive(Clone)]
pub struct Connector {
    /// Address of the HTTP proxy that every connection is made through.
    proxy_addr: SocketAddr,
    /// Basic-auth user name; only sent when `password` is also set.
    username: Option<String>,
    /// Basic-auth password; only sent when `username` is also set.
    password: Option<String>,
}
35
36 impl Connector {
new(proxy_addr: SocketAddr, username: Option<String>, password: Option<String>) -> Self37 pub fn new(proxy_addr: SocketAddr, username: Option<String>, password: Option<String>) -> Self {
38 Connector { proxy_addr, username, password }
39 }
40
connect(&self, addr: SocketAddr) -> Result<TcpStream>41 pub async fn connect(&self, addr: SocketAddr) -> Result<TcpStream> {
42 let mut stream = TcpStream::connect(self.proxy_addr).await?;
43
44 // Construct the CONNECT request
45 let mut request = format!("CONNECT {} HTTP/{}\r\n", addr.to_string(), HTTP_VERSION);
46
47 // Authentication
48 if let (Some(username), Some(password)) = (&self.username, &self.password) {
49 let encoded_auth = base64_encode(format!("{}:{}", username, password).as_bytes());
50 let auth_header = format!(
51 "Proxy-Authorization: Basic {}\r\n",
52 String::from_utf8_lossy(&encoded_auth)
53 );
54 // Add the header to the request
55 request.push_str(&auth_header);
56 }
57
58 // Add the final CRLF
59 request.push_str("\r\n");
60 stream.write_all(request.as_bytes()).await?;
61
62 // Read the proxy's response
63 let mut reader = BufReader::new(stream);
64 let mut response = String::new();
65 reader.read_line(&mut response).await?;
66 if response.starts_with(&format!("HTTP/{} 200", HTTP_VERSION)) {
67 Ok(reader.into_inner())
68 } else {
69 Err(Error::ConnectionError(addr, response.trim_end_matches("\r\n").to_string()))
70 }
71 }
72 }
73
base64_encode(src: &[u8]) -> Vec<u8>74 fn base64_encode(src: &[u8]) -> Vec<u8> {
75 general_purpose::STANDARD.encode(src).into_bytes()
76 }
77
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;

    /// Address the client asks the fake proxy to tunnel to. The fake proxy
    /// never dials it, so nothing needs to be listening there, and using a
    /// literal avoids depending on DNS resolution of `localhost`.
    fn target_addr() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 8000))
    }

    /// Reads one request head (request line, header fields and the
    /// terminating blank line) from `reader` and returns it verbatim.
    ///
    /// Reading up to the blank line, rather than issuing a single `read`,
    /// keeps the comparison deterministic even if the request arrives split
    /// across several TCP segments.
    async fn read_request_head(reader: &mut BufReader<TcpStream>) -> String {
        let mut head = String::new();
        loop {
            let mut line = String::new();
            let n = reader.read_line(&mut line).await.unwrap();
            assert!(n > 0, "client closed the connection mid-request; got {:?}", head);
            head.push_str(&line);
            if line == "\r\n" {
                return head;
            }
        }
    }

    /// Starts a one-shot fake proxy on an ephemeral localhost port.
    ///
    /// The proxy accepts a single connection, asserts that the request head
    /// equals `expected_head`, replies with `response` and then closes the
    /// socket. Returns the proxy address and the serving task's handle, which
    /// the test must await so that an assertion failure inside the task fails
    /// the test.
    async fn spawn_fake_proxy(
        expected_head: String,
        response: &'static str,
    ) -> (SocketAddr, JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let proxy_addr = listener.local_addr().unwrap();

        let handle = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let head = read_request_head(&mut reader).await;
            assert_eq!(head, expected_head);

            let mut stream = reader.into_inner();
            stream.write_all(response.as_bytes()).await.unwrap();
        });

        (proxy_addr, handle)
    }

    #[tokio::test]
    async fn test_connect() -> Result<()> {
        let addr = target_addr();
        // Without credentials the request head is only the request line.
        let (proxy_addr, proxy) = spawn_fake_proxy(
            format!("CONNECT {} HTTP/1.1\r\n\r\n", addr),
            "HTTP/1.1 200 Connection established\r\n\r\n",
        )
        .await;

        let result = Connector::new(proxy_addr, None, None).connect(addr).await;
        // Await the proxy first so its assertion failures take precedence
        // over the client-side error they would cause.
        proxy.await.unwrap();
        result?;

        Ok(())
    }

    #[tokio::test]
    async fn test_connect_with_auth() -> Result<()> {
        let addr = target_addr();
        // "dXNlcjpwYXNzd29yZA==" is the Base64 encoding of "user:password".
        let (proxy_addr, proxy) = spawn_fake_proxy(
            format!(
                "CONNECT {} HTTP/1.1\r\nProxy-Authorization: Basic dXNlcjpwYXNzd29yZA==\r\n\r\n",
                addr
            ),
            "HTTP/1.1 200 Connection established\r\n\r\n",
        )
        .await;

        let client = Connector::new(proxy_addr, Some("user".into()), Some("password".into()));
        let result = client.connect(addr).await;
        proxy.await.unwrap();
        result?;

        Ok(())
    }

    #[tokio::test]
    async fn test_connect_rejected() -> Result<()> {
        let addr = target_addr();
        let (proxy_addr, proxy) = spawn_fake_proxy(
            format!("CONNECT {} HTTP/1.1\r\n\r\n", addr),
            "HTTP/1.1 407 Proxy Authentication Required\r\n\r\n",
        )
        .await;

        let result = Connector::new(proxy_addr, None, None).connect(addr).await;
        proxy.await.unwrap();

        // A refusal is reported with the target address and the proxy's
        // status line, stripped of its CRLF terminator.
        match result {
            Err(Error::ConnectionError(target, status)) => {
                assert_eq!(target, addr);
                assert_eq!(status, "HTTP/1.1 407 Proxy Authentication Required");
            }
            Err(other) => panic!("unexpected error: {:?}", other),
            Ok(_) => panic!("a 407 reply must not yield a tunnel"),
        }

        Ok(())
    }

    #[test]
    fn test_proxy_base64_encode_success() {
        let input = b"hello world";
        let encoded = base64_encode(input);
        assert_eq!(encoded, b"aGVsbG8gd29ybGQ=");
    }

    #[test]
    fn test_proxy_base64_encode_empty_input() {
        let input = b"";
        let encoded = base64_encode(input);
        assert_eq!(encoded, b"");
    }
}
169