/*
 * Copyright (C) 2021, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! DoH (DNS over HTTPS) server frontend.
18
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use log::{debug, error, warn};
24 use std::fs::File;
25 use std::io::Write;
26 use std::os::unix::io::{AsRawFd, FromRawFd};
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29 use tokio::net::UdpSocket;
30 use tokio::runtime::{Builder, Runtime};
31 use tokio::sync::{mpsc, oneshot};
32 use tokio::task::JoinHandle;
33
34 static RUNTIME_STATIC: LazyLock<Arc<Runtime>> = LazyLock::new(|| {
35 Arc::new(
36 Builder::new_multi_thread()
37 .worker_threads(1)
38 .enable_all()
39 .thread_name("DohFrontend")
40 .build()
41 .expect("Failed to create tokio runtime"),
42 )
43 });
44
45 /// Command used by worker_thread itself.
46 #[derive(Debug)]
47 enum InternalCommand {
48 MaybeWrite { connection_id: ConnectionID },
49 }
50
51 /// Commands that DohFrontend to ask its worker_thread for.
52 #[derive(Debug)]
53 enum ControlCommand {
54 Stats { resp: oneshot::Sender<Stats> },
55 StatsClearQueries,
56 CloseConnection,
57 }
58
59 /// Frontend object.
60 #[derive(Debug)]
61 pub struct DohFrontend {
62 // Socket address the frontend listens to.
63 listen_socket_addr: std::net::SocketAddr,
64
65 // Socket address the backend listens to.
66 backend_socket_addr: std::net::SocketAddr,
67
68 /// The content of the certificate.
69 certificate: String,
70
71 /// The content of the private key.
72 private_key: String,
73
74 // The thread listening to frontend socket and backend socket
75 // and processing the messages.
76 worker_thread: Option<JoinHandle<Result<()>>>,
77
78 // Custom runtime configuration to control the behavior of the worker thread.
79 // It's shared with the worker thread.
80 // TODO: use channel to update worker_thread configuration.
81 config: Arc<Mutex<Config>>,
82
83 // Caches the latest stats so that the stats remains after worker_thread stops.
84 latest_stats: Stats,
85
86 // It is wrapped as Option because the channel is not created in DohFrontend construction.
87 command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
88 }
89
90 /// The parameters passed to the worker thread.
91 struct WorkerParams {
92 frontend_socket: std::net::UdpSocket,
93 backend_socket: std::net::UdpSocket,
94 clients: ClientMap,
95 config: Arc<Mutex<Config>>,
96 command_rx: mpsc::UnboundedReceiver<ControlCommand>,
97 }
98
99 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>100 pub fn new(
101 listen: std::net::SocketAddr,
102 backend: std::net::SocketAddr,
103 ) -> Result<Box<DohFrontend>> {
104 let doh = Box::new(DohFrontend {
105 listen_socket_addr: listen,
106 backend_socket_addr: backend,
107 certificate: String::new(),
108 private_key: String::new(),
109 worker_thread: None,
110 config: Arc::new(Mutex::new(Config::new())),
111 latest_stats: Stats::new(),
112 command_tx: None,
113 });
114 debug!("DohFrontend created: {:?}", doh);
115 Ok(doh)
116 }
117
start(&mut self) -> Result<()>118 pub fn start(&mut self) -> Result<()> {
119 ensure!(self.worker_thread.is_none(), "Worker thread has been running");
120 ensure!(!self.certificate.is_empty(), "certificate is empty");
121 ensure!(!self.private_key.is_empty(), "private_key is empty");
122
123 // Doing error handling here is much simpler.
124 let params = match self.init_worker_thread_params() {
125 Ok(v) => v,
126 Err(e) => return Err(e.context("init_worker_thread_params failed")),
127 };
128
129 self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
130 Ok(())
131 }
132
stop(&mut self) -> Result<()>133 pub fn stop(&mut self) -> Result<()> {
134 debug!("DohFrontend: stopping: {:?}", self);
135 if let Some(worker_thread) = self.worker_thread.take() {
136 // Update latest_stats before stopping worker_thread.
137 let _ = self.request_stats();
138
139 self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
140 if let Err(e) = self.wait_for_connections_closed() {
141 warn!("wait_for_connections_closed failed: {}", e);
142 }
143
144 worker_thread.abort();
145 RUNTIME_STATIC.block_on(async {
146 debug!("worker_thread result: {:?}", worker_thread.await);
147 })
148 }
149
150 debug!("DohFrontend: stopped: {:?}", self);
151 Ok(())
152 }
153
set_certificate(&mut self, certificate: &str) -> Result<()>154 pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
155 self.certificate = certificate.to_string();
156 Ok(())
157 }
158
set_private_key(&mut self, private_key: &str) -> Result<()>159 pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
160 self.private_key = private_key.to_string();
161 Ok(())
162 }
163
set_delay_queries(&self, value: i32) -> Result<()>164 pub fn set_delay_queries(&self, value: i32) -> Result<()> {
165 self.config.lock().unwrap().delay_queries = value;
166 Ok(())
167 }
168
set_max_idle_timeout(&self, value: u64) -> Result<()>169 pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
170 self.config.lock().unwrap().max_idle_timeout = value;
171 Ok(())
172 }
173
set_max_buffer_size(&self, value: u64) -> Result<()>174 pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
175 self.config.lock().unwrap().max_buffer_size = value;
176 Ok(())
177 }
178
set_max_streams_bidi(&self, value: u64) -> Result<()>179 pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
180 self.config.lock().unwrap().max_streams_bidi = value;
181 Ok(())
182 }
183
block_sending(&self, value: bool) -> Result<()>184 pub fn block_sending(&self, value: bool) -> Result<()> {
185 self.config.lock().unwrap().block_sending = value;
186 Ok(())
187 }
188
set_reset_stream_id(&self, value: u64) -> Result<()>189 pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
190 self.config.lock().unwrap().reset_stream_id = Some(value);
191 Ok(())
192 }
193
request_stats(&mut self) -> Result<Stats>194 pub fn request_stats(&mut self) -> Result<Stats> {
195 ensure!(
196 self.command_tx.is_some(),
197 "command_tx is None because worker thread not yet initialized"
198 );
199 let command_tx = self.command_tx.as_ref().unwrap();
200
201 if command_tx.is_closed() {
202 return Ok(self.latest_stats.clone());
203 }
204
205 let (resp_tx, resp_rx) = oneshot::channel();
206 command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
207
208 match RUNTIME_STATIC
209 .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
210 {
211 Ok(v) => match v {
212 Ok(stats) => {
213 self.latest_stats = stats.clone();
214 Ok(stats)
215 }
216 Err(e) => bail!(e),
217 },
218 Err(e) => bail!(e),
219 }
220 }
221
stats_clear_queries(&self) -> Result<()>222 pub fn stats_clear_queries(&self) -> Result<()> {
223 ensure!(
224 self.command_tx.is_some(),
225 "command_tx is None because worker thread not yet initialized"
226 );
227 return self
228 .command_tx
229 .as_ref()
230 .unwrap()
231 .send(ControlCommand::StatsClearQueries)
232 .or_else(|e| bail!(e));
233 }
234
init_worker_thread_params(&mut self) -> Result<WorkerParams>235 fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
236 let bind_addr =
237 if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
238 let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
239 backend_socket.connect(self.backend_socket_addr)?;
240 backend_socket.set_nonblocking(true)?;
241
242 let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
243 frontend_socket.set_nonblocking(true)?;
244
245 let clients = ClientMap::new(create_quiche_config(
246 self.certificate.to_string(),
247 self.private_key.to_string(),
248 self.config.clone(),
249 )?)?;
250
251 let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
252 self.command_tx = Some(command_tx);
253
254 Ok(WorkerParams {
255 frontend_socket,
256 backend_socket,
257 clients,
258 config: self.config.clone(),
259 command_rx,
260 })
261 }
262
wait_for_connections_closed(&mut self) -> Result<()>263 fn wait_for_connections_closed(&mut self) -> Result<()> {
264 for _ in 0..3 {
265 std::thread::sleep(Duration::from_millis(50));
266 match self.request_stats() {
267 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
268 Ok(_) => (),
269
270 // The worker thread is down. No connection is alive.
271 Err(_) => return Ok(()),
272 }
273 }
274 bail!("Some connections still alive")
275 }
276 }
277
worker_thread(params: WorkerParams) -> Result<()>278 async fn worker_thread(params: WorkerParams) -> Result<()> {
279 let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
280 let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
281 let config = params.config;
282 let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
283 let mut command_rx = params.command_rx;
284 let mut clients = params.clients;
285 let mut frontend_buf = [0; 65535];
286 let mut backend_buf = [0; 16384];
287 let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
288 let mut queries_received = 0;
289
290 debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);
291
292 loop {
293 let timeout = clients
294 .iter_mut()
295 .filter_map(|(_, c)| c.timeout())
296 .min()
297 .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
298
299 tokio::select! {
300 _ = tokio::time::sleep(timeout) => {
301 debug!("timeout");
302 for (_, client) in clients.iter_mut() {
303 // If no timeout has occurred it does nothing.
304 client.on_timeout();
305
306 let connection_id = client.connection_id().clone();
307 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
308 }
309 }
310
311 Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
312 debug!("Got {} bytes from {}", len, peer);
313
314 // Parse QUIC packet.
315 let pkt_buf = &mut frontend_buf[..len];
316 let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
317 Ok(v) => v,
318 Err(e) => {
319 error!("Failed to parse QUIC header: {:?}", e);
320 continue;
321 }
322 };
323 debug!("Got QUIC packet: {:?}", hdr);
324
325 let local = frontend_socket.local_addr()?;
326 let client = match clients.get_or_create(&hdr, &peer, &local) {
327 Ok(v) => v,
328 Err(e) => {
329 error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
330 continue;
331 }
332 };
333 debug!("Got client: {:?}", client);
334
335 match client.handle_frontend_message(pkt_buf, &local) {
336 Ok(v) if !v.is_empty() => {
337 delay_queries_buffer.push(v);
338 queries_received += 1;
339 }
340 Err(e) => {
341 error!("Failed to process QUIC packet: {}", e);
342 continue;
343 }
344 _ => {}
345 }
346
347 if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
348 for query in delay_queries_buffer.drain(..) {
349 debug!("sending {} bytes to backend", query.len());
350 backend_socket.send(&query).await?;
351 }
352 }
353
354 let connection_id = client.connection_id().clone();
355 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
356 }
357
358 Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
359 debug!("Got {} bytes from {}", len, src);
360 if len < DNS_HEADER_SIZE {
361 error!("Received insufficient bytes for DNS header");
362 continue;
363 }
364
365 let query_id = [backend_buf[0], backend_buf[1]];
366 for (_, client) in clients.iter_mut() {
367 if client.is_waiting_for_query(&query_id) {
368 let reset_stream_id = config.lock().unwrap().reset_stream_id;
369 if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
370 error!("Failed to handle message from backend: {}", e);
371 }
372 let connection_id = client.connection_id().clone();
373 event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
374
375 // It's a bug if more than one client is waiting for this query.
376 break;
377 }
378 }
379 }
380
381 Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
382 match command {
383 InternalCommand::MaybeWrite {connection_id} => {
384 if let Some(client) = clients.get_mut(&connection_id) {
385 while let Ok(v) = client.flush_egress() {
386 let addr = client.addr();
387 debug!("Sending {} bytes to client {}", v.len(), addr);
388 if let Err(e) = frontend_socket.send_to(&v, addr).await {
389 error!("Failed to send packet to {:?}: {:?}", client, e);
390 }
391 }
392 client.process_pending_answers()?;
393 }
394 }
395 }
396 }
397 Some(command) = command_rx.recv() => {
398 debug!("ControlCommand: {:?}", command);
399 match command {
400 ControlCommand::Stats {resp} => {
401 let stats = Stats {
402 queries_received,
403 connections_accepted: clients.len() as u32,
404 alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
405 resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
406 early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
407 };
408 if let Err(e) = resp.send(stats) {
409 error!("Failed to send ControlCommand::Stats response: {:?}", e);
410 }
411 }
412 ControlCommand::StatsClearQueries => queries_received = 0,
413 ControlCommand::CloseConnection => {
414 for (_, client) in clients.iter_mut() {
415 client.close();
416 event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
417 }
418 }
419 }
420 }
421 }
422 }
423 }
424
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>425 fn create_quiche_config(
426 certificate: String,
427 private_key: String,
428 config: Arc<Mutex<Config>>,
429 ) -> Result<quiche::Config> {
430 let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
431
432 // Use pipe as a file path for Quiche to read the certificate and the private key.
433 let (rd, mut wr) = build_pipe()?;
434 let handle = std::thread::spawn(move || {
435 wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
436 });
437 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
438 quiche_config.load_cert_chain_from_pem_file(&filepath)?;
439 handle.join().unwrap();
440
441 let (rd, mut wr) = build_pipe()?;
442 let handle = std::thread::spawn(move || {
443 wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
444 });
445 let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
446 quiche_config.load_priv_key_from_pem_file(&filepath)?;
447 handle.join().unwrap();
448
449 quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
450 quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
451 quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
452
453 let max_buffer_size = config.lock().unwrap().max_buffer_size;
454 quiche_config.set_initial_max_data(max_buffer_size);
455 quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
456 quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
457 quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
458
459 quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
460 quiche_config.set_initial_max_streams_uni(100);
461 quiche_config.set_disable_active_migration(true);
462 quiche_config.enable_early_data();
463
464 Ok(quiche_config)
465 }
466
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>467 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
468 match UdpSocket::from_std(socket) {
469 Ok(v) => Ok(v),
470 Err(e) => {
471 error!("into_tokio_udp_socket failed: {}", e);
472 bail!("into_tokio_udp_socket failed: {}", e)
473 }
474 }
475 }
476
build_pipe() -> Result<(File, File)>477 fn build_pipe() -> Result<(File, File)> {
478 let mut fds = [0, 0];
479 // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
480 // file descriptors it returns must be valid and open, so they are safe to pass to
481 // `File::from_raw_fd`.
482 unsafe {
483 if libc::pipe(fds.as_mut_ptr()) == 0 {
484 return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
485 }
486 }
487 Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
488 }
489
490 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>491 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
492 for _ in 0..3 {
493 match std::net::UdpSocket::bind(addr) {
494 Ok(socket) => return Ok(socket),
495 Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
496 warn!("Binding socket address {} that is in use. Try again", addr);
497 std::thread::sleep(Duration::from_millis(50));
498 }
499 Err(e) => return Err(anyhow::anyhow!(e)),
500 }
501 }
502 Err(anyhow::anyhow!(std::io::Error::last_os_error()))
503 }
504