1 /*
2  * Copyright (C) 2021, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 //! DoH server frontend.
18 
19 use super::client::{ClientMap, ConnectionID, CONN_ID_LEN, DNS_HEADER_SIZE, MAX_UDP_PAYLOAD_SIZE};
20 use super::config::{Config, QUICHE_IDLE_TIMEOUT_MS};
21 use super::stats::Stats;
22 use anyhow::{bail, ensure, Result};
23 use log::{debug, error, warn};
24 use std::fs::File;
25 use std::io::Write;
26 use std::os::unix::io::{AsRawFd, FromRawFd};
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29 use tokio::net::UdpSocket;
30 use tokio::runtime::{Builder, Runtime};
31 use tokio::sync::{mpsc, oneshot};
32 use tokio::task::JoinHandle;
33 
34 static RUNTIME_STATIC: LazyLock<Arc<Runtime>> = LazyLock::new(|| {
35     Arc::new(
36         Builder::new_multi_thread()
37             .worker_threads(1)
38             .enable_all()
39             .thread_name("DohFrontend")
40             .build()
41             .expect("Failed to create tokio runtime"),
42     )
43 });
44 
/// Commands that the worker thread sends to itself through its internal event channel.
#[derive(Debug)]
enum InternalCommand {
    /// Asks the worker to flush the pending egress packets of the client identified by
    /// `connection_id` and then process that client's pending answers.
    MaybeWrite { connection_id: ConnectionID },
}
50 
/// Commands that `DohFrontend` sends to its worker thread.
#[derive(Debug)]
enum ControlCommand {
    /// Requests a snapshot of the current statistics; the worker replies through `resp`.
    Stats { resp: oneshot::Sender<Stats> },
    /// Resets the worker's received-queries counter to zero.
    StatsClearQueries,
    /// Closes every client connection the worker currently tracks.
    CloseConnection,
}
58 
/// Frontend object: owns the configuration of the DoH server and the handle of the worker
/// task that serves it.
///
/// NOTE(review): the derived `Debug` output includes `certificate` and `private_key`, and
/// this struct is logged with `{:?}` at debug level — confirm that this is acceptable.
#[derive(Debug)]
pub struct DohFrontend {
    /// Socket address the frontend listens to.
    listen_socket_addr: std::net::SocketAddr,

    /// Socket address the backend listens to.
    backend_socket_addr: std::net::SocketAddr,

    /// The content of the certificate.
    certificate: String,

    /// The content of the private key.
    private_key: String,

    /// The task listening to the frontend socket and the backend socket
    /// and processing the messages. `None` while the frontend is not started.
    worker_thread: Option<JoinHandle<Result<()>>>,

    /// Custom runtime configuration to control the behavior of the worker thread.
    /// It's shared with the worker thread.
    // TODO: use channel to update worker_thread configuration.
    config: Arc<Mutex<Config>>,

    /// Caches the latest stats so that the stats remain available after worker_thread stops.
    latest_stats: Stats,

    /// Sender half of the control channel to the worker thread.
    /// It is wrapped as Option because the channel is not created in DohFrontend construction.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}
89 
/// The parameters passed to the worker thread.
struct WorkerParams {
    /// Socket bound to the frontend listen address, already set to non-blocking mode.
    frontend_socket: std::net::UdpSocket,
    /// Socket connected to the backend address, already set to non-blocking mode.
    backend_socket: std::net::UdpSocket,
    /// Client connection table, created from the quiche configuration.
    clients: ClientMap,
    /// Runtime configuration shared with `DohFrontend`.
    config: Arc<Mutex<Config>>,
    /// Receiver half of the control channel; `DohFrontend` keeps the sender.
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}
98 
99 impl DohFrontend {
new( listen: std::net::SocketAddr, backend: std::net::SocketAddr, ) -> Result<Box<DohFrontend>>100     pub fn new(
101         listen: std::net::SocketAddr,
102         backend: std::net::SocketAddr,
103     ) -> Result<Box<DohFrontend>> {
104         let doh = Box::new(DohFrontend {
105             listen_socket_addr: listen,
106             backend_socket_addr: backend,
107             certificate: String::new(),
108             private_key: String::new(),
109             worker_thread: None,
110             config: Arc::new(Mutex::new(Config::new())),
111             latest_stats: Stats::new(),
112             command_tx: None,
113         });
114         debug!("DohFrontend created: {:?}", doh);
115         Ok(doh)
116     }
117 
start(&mut self) -> Result<()>118     pub fn start(&mut self) -> Result<()> {
119         ensure!(self.worker_thread.is_none(), "Worker thread has been running");
120         ensure!(!self.certificate.is_empty(), "certificate is empty");
121         ensure!(!self.private_key.is_empty(), "private_key is empty");
122 
123         // Doing error handling here is much simpler.
124         let params = match self.init_worker_thread_params() {
125             Ok(v) => v,
126             Err(e) => return Err(e.context("init_worker_thread_params failed")),
127         };
128 
129         self.worker_thread = Some(RUNTIME_STATIC.spawn(worker_thread(params)));
130         Ok(())
131     }
132 
stop(&mut self) -> Result<()>133     pub fn stop(&mut self) -> Result<()> {
134         debug!("DohFrontend: stopping: {:?}", self);
135         if let Some(worker_thread) = self.worker_thread.take() {
136             // Update latest_stats before stopping worker_thread.
137             let _ = self.request_stats();
138 
139             self.command_tx.as_ref().unwrap().send(ControlCommand::CloseConnection)?;
140             if let Err(e) = self.wait_for_connections_closed() {
141                 warn!("wait_for_connections_closed failed: {}", e);
142             }
143 
144             worker_thread.abort();
145             RUNTIME_STATIC.block_on(async {
146                 debug!("worker_thread result: {:?}", worker_thread.await);
147             })
148         }
149 
150         debug!("DohFrontend: stopped: {:?}", self);
151         Ok(())
152     }
153 
set_certificate(&mut self, certificate: &str) -> Result<()>154     pub fn set_certificate(&mut self, certificate: &str) -> Result<()> {
155         self.certificate = certificate.to_string();
156         Ok(())
157     }
158 
set_private_key(&mut self, private_key: &str) -> Result<()>159     pub fn set_private_key(&mut self, private_key: &str) -> Result<()> {
160         self.private_key = private_key.to_string();
161         Ok(())
162     }
163 
set_delay_queries(&self, value: i32) -> Result<()>164     pub fn set_delay_queries(&self, value: i32) -> Result<()> {
165         self.config.lock().unwrap().delay_queries = value;
166         Ok(())
167     }
168 
set_max_idle_timeout(&self, value: u64) -> Result<()>169     pub fn set_max_idle_timeout(&self, value: u64) -> Result<()> {
170         self.config.lock().unwrap().max_idle_timeout = value;
171         Ok(())
172     }
173 
set_max_buffer_size(&self, value: u64) -> Result<()>174     pub fn set_max_buffer_size(&self, value: u64) -> Result<()> {
175         self.config.lock().unwrap().max_buffer_size = value;
176         Ok(())
177     }
178 
set_max_streams_bidi(&self, value: u64) -> Result<()>179     pub fn set_max_streams_bidi(&self, value: u64) -> Result<()> {
180         self.config.lock().unwrap().max_streams_bidi = value;
181         Ok(())
182     }
183 
block_sending(&self, value: bool) -> Result<()>184     pub fn block_sending(&self, value: bool) -> Result<()> {
185         self.config.lock().unwrap().block_sending = value;
186         Ok(())
187     }
188 
set_reset_stream_id(&self, value: u64) -> Result<()>189     pub fn set_reset_stream_id(&self, value: u64) -> Result<()> {
190         self.config.lock().unwrap().reset_stream_id = Some(value);
191         Ok(())
192     }
193 
request_stats(&mut self) -> Result<Stats>194     pub fn request_stats(&mut self) -> Result<Stats> {
195         ensure!(
196             self.command_tx.is_some(),
197             "command_tx is None because worker thread not yet initialized"
198         );
199         let command_tx = self.command_tx.as_ref().unwrap();
200 
201         if command_tx.is_closed() {
202             return Ok(self.latest_stats.clone());
203         }
204 
205         let (resp_tx, resp_rx) = oneshot::channel();
206         command_tx.send(ControlCommand::Stats { resp: resp_tx })?;
207 
208         match RUNTIME_STATIC
209             .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
210         {
211             Ok(v) => match v {
212                 Ok(stats) => {
213                     self.latest_stats = stats.clone();
214                     Ok(stats)
215                 }
216                 Err(e) => bail!(e),
217             },
218             Err(e) => bail!(e),
219         }
220     }
221 
stats_clear_queries(&self) -> Result<()>222     pub fn stats_clear_queries(&self) -> Result<()> {
223         ensure!(
224             self.command_tx.is_some(),
225             "command_tx is None because worker thread not yet initialized"
226         );
227         return self
228             .command_tx
229             .as_ref()
230             .unwrap()
231             .send(ControlCommand::StatsClearQueries)
232             .or_else(|e| bail!(e));
233     }
234 
init_worker_thread_params(&mut self) -> Result<WorkerParams>235     fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
236         let bind_addr =
237             if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
238         let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
239         backend_socket.connect(self.backend_socket_addr)?;
240         backend_socket.set_nonblocking(true)?;
241 
242         let frontend_socket = bind_udp_socket_retry(self.listen_socket_addr)?;
243         frontend_socket.set_nonblocking(true)?;
244 
245         let clients = ClientMap::new(create_quiche_config(
246             self.certificate.to_string(),
247             self.private_key.to_string(),
248             self.config.clone(),
249         )?)?;
250 
251         let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
252         self.command_tx = Some(command_tx);
253 
254         Ok(WorkerParams {
255             frontend_socket,
256             backend_socket,
257             clients,
258             config: self.config.clone(),
259             command_rx,
260         })
261     }
262 
wait_for_connections_closed(&mut self) -> Result<()>263     fn wait_for_connections_closed(&mut self) -> Result<()> {
264         for _ in 0..3 {
265             std::thread::sleep(Duration::from_millis(50));
266             match self.request_stats() {
267                 Ok(stats) if stats.alive_connections == 0 => return Ok(()),
268                 Ok(_) => (),
269 
270                 // The worker thread is down. No connection is alive.
271                 Err(_) => return Ok(()),
272             }
273         }
274         bail!("Some connections still alive")
275     }
276 }
277 
/// Event loop of the DoH frontend, spawned by `DohFrontend::start` on `RUNTIME_STATIC`.
///
/// After converting both sockets into tokio sockets, it loops over a `tokio::select!` that
/// handles five kinds of events: the connection timeout, packets from the frontend socket,
/// responses from the backend socket, internal `MaybeWrite` events, and `ControlCommand`s
/// sent by `DohFrontend`.
///
/// # Errors
///
/// The loop has no normal exit: the function only returns when one of the `?` operators
/// below propagates an error. Otherwise it runs until the task is aborted.
async fn worker_thread(params: WorkerParams) -> Result<()> {
    let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
    let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
    let config = params.config;
    // Channel the loop uses to schedule writes for itself.
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
    let mut command_rx = params.command_rx;
    let mut clients = params.clients;
    let mut frontend_buf = [0; 65535];
    let mut backend_buf = [0; 16384];
    // Queries held back until enough of them have accumulated (see `delay_queries` below).
    let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
    let mut queries_received = 0;

    debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);

    loop {
        // Sleep for the smallest timeout reported by any client, or for
        // QUICHE_IDLE_TIMEOUT_MS when no client reports one.
        let timeout = clients
            .iter_mut()
            .filter_map(|(_, c)| c.timeout())
            .min()
            .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));

        tokio::select! {
            // The timeout elapsed: notify every client and schedule a write for each.
            _ = tokio::time::sleep(timeout) => {
                debug!("timeout");
                for (_, client) in clients.iter_mut() {
                    // If no timeout has occurred it does nothing.
                    client.on_timeout();

                    let connection_id = client.connection_id().clone();
                    event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
                }
            }

            // A datagram arrived on the frontend socket.
            Ok((len, peer)) = frontend_socket.recv_from(&mut frontend_buf) => {
                debug!("Got {} bytes from {}", len, peer);

                // Parse QUIC packet.
                let pkt_buf = &mut frontend_buf[..len];
                let hdr = match quiche::Header::from_slice(pkt_buf, CONN_ID_LEN) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to parse QUIC header: {:?}", e);
                        continue;
                    }
                };
                debug!("Got QUIC packet: {:?}", hdr);

                let local = frontend_socket.local_addr()?;
                let client = match clients.get_or_create(&hdr, &peer, &local) {
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
                        continue;
                    }
                };
                debug!("Got client: {:?}", client);

                // A non-empty result is queued for the backend and counted as a received
                // query; an empty result needs no forwarding.
                match client.handle_frontend_message(pkt_buf, &local) {
                    Ok(v) if !v.is_empty() => {
                        delay_queries_buffer.push(v);
                        queries_received += 1;
                    }
                    Err(e) => {
                        error!("Failed to process QUIC packet: {}", e);
                        continue;
                    }
                    _ => {}
                }

                // Forward the queued queries once at least `delay_queries` are buffered.
                // NOTE(review): `set_delay_queries` accepts an `i32`; a negative value would
                // cast to a very large `usize` here, so nothing would ever be forwarded —
                // confirm that negative values are never configured.
                if delay_queries_buffer.len() >= config.lock().unwrap().delay_queries as usize {
                    for query in delay_queries_buffer.drain(..) {
                        debug!("sending {} bytes to backend", query.len());
                        backend_socket.send(&query).await?;
                    }
                }

                let connection_id = client.connection_id().clone();
                event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
            }

            // A datagram arrived from the backend.
            Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
                debug!("Got {} bytes from {}", len, src);
                if len < DNS_HEADER_SIZE {
                    error!("Received insufficient bytes for DNS header");
                    continue;
                }

                // The first two bytes of the message are used as the query ID to find the
                // client waiting for this response.
                let query_id = [backend_buf[0], backend_buf[1]];
                for (_, client) in clients.iter_mut() {
                    if client.is_waiting_for_query(&query_id) {
                        let reset_stream_id = config.lock().unwrap().reset_stream_id;
                        if let Err(e) = client.handle_backend_message(&backend_buf[..len], reset_stream_id) {
                            error!("Failed to handle message from backend: {}", e);
                        }
                        let connection_id = client.connection_id().clone();
                        event_tx.send(InternalCommand::MaybeWrite{connection_id})?;

                        // It's a bug if more than one client is waiting for this query.
                        break;
                    }
                }
            }

            // Internal write events. The guard skips this branch while `block_sending` is
            // set, so queued events stay in the channel until sending is unblocked.
            Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
                match command {
                    InternalCommand::MaybeWrite {connection_id} => {
                        if let Some(client) = clients.get_mut(&connection_id) {
                            // Send packets until `flush_egress` stops returning `Ok`.
                            while let Ok(v) = client.flush_egress() {
                                let addr = client.addr();
                                debug!("Sending {} bytes to client {}", v.len(), addr);
                                if let Err(e) = frontend_socket.send_to(&v, addr).await {
                                    error!("Failed to send packet to {:?}: {:?}", client, e);
                                }
                            }
                            client.process_pending_answers()?;
                        }
                    }
                }
            }
            // Control commands sent by `DohFrontend`.
            Some(command) = command_rx.recv() => {
                debug!("ControlCommand: {:?}", command);
                match command {
                    ControlCommand::Stats {resp} => {
                        // Build a snapshot from the counter and the current client table.
                        let stats = Stats {
                            queries_received,
                            connections_accepted: clients.len() as u32,
                            alive_connections: clients.iter().filter(|(_, client)| client.is_alive()).count() as u32,
                            resumed_connections: clients.iter().filter(|(_, client)| client.is_resumed()).count() as u32,
                            early_data_connections: clients.iter().filter(|(_, client)| client.handled_early_data()).count() as u32,
                        };
                        if let Err(e) = resp.send(stats) {
                            error!("Failed to send ControlCommand::Stats response: {:?}", e);
                        }
                    }
                    ControlCommand::StatsClearQueries => queries_received = 0,
                    ControlCommand::CloseConnection => {
                        // Close every client and schedule a write for each of them.
                        for (_, client) in clients.iter_mut() {
                            client.close();
                            event_tx.send(InternalCommand::MaybeWrite { connection_id: client.connection_id().clone() })?;
                        }
                    }
                }
            }
        }
    }
}
424 
create_quiche_config( certificate: String, private_key: String, config: Arc<Mutex<Config>>, ) -> Result<quiche::Config>425 fn create_quiche_config(
426     certificate: String,
427     private_key: String,
428     config: Arc<Mutex<Config>>,
429 ) -> Result<quiche::Config> {
430     let mut quiche_config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
431 
432     // Use pipe as a file path for Quiche to read the certificate and the private key.
433     let (rd, mut wr) = build_pipe()?;
434     let handle = std::thread::spawn(move || {
435         wr.write_all(certificate.as_bytes()).expect("Failed to write to pipe");
436     });
437     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
438     quiche_config.load_cert_chain_from_pem_file(&filepath)?;
439     handle.join().unwrap();
440 
441     let (rd, mut wr) = build_pipe()?;
442     let handle = std::thread::spawn(move || {
443         wr.write_all(private_key.as_bytes()).expect("Failed to write to pipe");
444     });
445     let filepath = format!("/proc/self/fd/{}", rd.as_raw_fd());
446     quiche_config.load_priv_key_from_pem_file(&filepath)?;
447     handle.join().unwrap();
448 
449     quiche_config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?;
450     quiche_config.set_max_idle_timeout(config.lock().unwrap().max_idle_timeout);
451     quiche_config.set_max_recv_udp_payload_size(MAX_UDP_PAYLOAD_SIZE);
452 
453     let max_buffer_size = config.lock().unwrap().max_buffer_size;
454     quiche_config.set_initial_max_data(max_buffer_size);
455     quiche_config.set_initial_max_stream_data_bidi_local(max_buffer_size);
456     quiche_config.set_initial_max_stream_data_bidi_remote(max_buffer_size);
457     quiche_config.set_initial_max_stream_data_uni(max_buffer_size);
458 
459     quiche_config.set_initial_max_streams_bidi(config.lock().unwrap().max_streams_bidi);
460     quiche_config.set_initial_max_streams_uni(100);
461     quiche_config.set_disable_active_migration(true);
462     quiche_config.enable_early_data();
463 
464     Ok(quiche_config)
465 }
466 
into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket>467 fn into_tokio_udp_socket(socket: std::net::UdpSocket) -> Result<UdpSocket> {
468     match UdpSocket::from_std(socket) {
469         Ok(v) => Ok(v),
470         Err(e) => {
471             error!("into_tokio_udp_socket failed: {}", e);
472             bail!("into_tokio_udp_socket failed: {}", e)
473         }
474     }
475 }
476 
build_pipe() -> Result<(File, File)>477 fn build_pipe() -> Result<(File, File)> {
478     let mut fds = [0, 0];
479     // SAFETY: The pointer we pass to `pipe` must be valid because it comes from a reference. The
480     // file descriptors it returns must be valid and open, so they are safe to pass to
481     // `File::from_raw_fd`.
482     unsafe {
483         if libc::pipe(fds.as_mut_ptr()) == 0 {
484             return Ok((File::from_raw_fd(fds[0]), File::from_raw_fd(fds[1])));
485         }
486     }
487     Err(anyhow::Error::new(std::io::Error::last_os_error()).context("build_pipe failed"))
488 }
489 
490 // Can retry to bind the socket address if it is in use.
bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket>491 fn bind_udp_socket_retry(addr: std::net::SocketAddr) -> Result<std::net::UdpSocket> {
492     for _ in 0..3 {
493         match std::net::UdpSocket::bind(addr) {
494             Ok(socket) => return Ok(socket),
495             Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
496                 warn!("Binding socket address {} that is in use. Try again", addr);
497                 std::thread::sleep(Duration::from_millis(50));
498             }
499             Err(e) => return Err(anyhow::anyhow!(e)),
500         }
501     }
502     Err(anyhow::anyhow!(std::io::Error::last_os_error()))
503 }
504