1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/forwarder.rs
17 
//! This module contains a forwarding mechanism that relays traffic between pairs of stream sockets.
19 
20 use std::fmt;
21 use std::io::{self, Read, Write};
22 use std::result;
23 
24 use crate::stream::StreamSocket;
25 
// Upper bound, in bytes, on how much a single forwarding step reads from the source stream
// (and therefore writes to the destination). It sizes the local buffer used by `forward`.
//
// This was picked arbitrarily. crosvm doesn't yet use VIRTIO_NET_F_MTU, so there's no reason to
// opt for massive 65535 byte frames.
const MAX_FRAME_SIZE: usize = 8192;
29 
/// Errors that can be encountered by a ForwarderSession.
///
/// Each variant wraps the underlying `io::Error` and records which step of forwarding failed.
// Variants are kept in alphabetical order; `#[remain::sorted]` enforces this at compile time.
#[remain::sorted]
#[derive(Debug)]
pub enum ForwarderError {
    /// An io::Error was encountered while reading from the source stream.
    ReadFromStream(io::Error),
    /// An io::Error was encountered while shutting down writes on the destination stream,
    /// which happens after the source stream reports EOF.
    ShutDownStream(io::Error),
    /// An io::Error was encountered while writing forwarded data to the destination stream.
    WriteToStream(io::Error),
}

/// Result type returned by the forwarding operations in this module.
type Result<T> = result::Result<T, ForwarderError>;
43 
44 impl fmt::Display for ForwarderError {
45     #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result46     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47         use self::ForwarderError::*;
48 
49         #[remain::sorted]
50         match self {
51             ReadFromStream(e) => write!(f, "failed to read from stream: {}", e),
52             ShutDownStream(e) => write!(f, "failed to shut down stream: {}", e),
53             WriteToStream(e) => write!(f, "failed to write to stream: {}", e),
54         }
55     }
56 }
57 
/// A ForwarderSession owns a pair of stream sockets and forwards traffic between them
/// in both directions.
pub struct ForwarderSession {
    /// The local end; its incoming data is sent to `remote`.
    local: StreamSocket,
    /// The remote end; its incoming data is sent to `local`.
    remote: StreamSocket,
}
63 
forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool>64 fn forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool> {
65     let mut buf = [0u8; MAX_FRAME_SIZE];
66 
67     let count = from_stream.read(&mut buf).map_err(ForwarderError::ReadFromStream)?;
68     if count == 0 {
69         to_stream.shut_down_write().map_err(ForwarderError::ShutDownStream)?;
70         return Ok(true);
71     }
72 
73     to_stream.write_all(&buf[..count]).map_err(ForwarderError::WriteToStream)?;
74     Ok(false)
75 }
76 
77 impl ForwarderSession {
78     /// Creates a forwarder session from a local and remote stream socket.
new(local: StreamSocket, remote: StreamSocket) -> Self79     pub fn new(local: StreamSocket, remote: StreamSocket) -> Self {
80         ForwarderSession { local, remote }
81     }
82 
83     /// Forwards traffic from the local socket to the remote socket.
84     /// Returns true if the local socket has reached EOF and the
85     /// remote socket has been shut down for further writes.
forward_from_local(&mut self) -> Result<bool>86     pub fn forward_from_local(&mut self) -> Result<bool> {
87         forward(&mut self.local, &mut self.remote)
88     }
89 
90     /// Forwards traffic from the remote socket to the local socket.
91     /// Returns true if the remote socket has reached EOF and the
92     /// local socket has been shut down for further writes.
forward_from_remote(&mut self) -> Result<bool>93     pub fn forward_from_remote(&mut self) -> Result<bool> {
94         forward(&mut self.remote, &mut self.local)
95     }
96 
97     /// Returns a reference to the local stream socket.
local_stream(&self) -> &StreamSocket98     pub fn local_stream(&self) -> &StreamSocket {
99         &self.local
100     }
101 
102     /// Returns a reference to the remote stream socket.
remote_stream(&self) -> &StreamSocket103     pub fn remote_stream(&self) -> &StreamSocket {
104         &self.remote
105     }
106 
107     /// Returns true if both sockets are completely shut down and the session can be dropped.
is_shut_down(&self) -> bool108     pub fn is_shut_down(&self) -> bool {
109         self.local.is_shut_down() && self.remote.is_shut_down()
110     }
111 }
112 
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::Shutdown;
    use std::os::unix::net::UnixStream;

    /// Performs a single `read` of at most 8 bytes from `stream` and returns the bytes received.
    /// An empty result means the peer has shut down its write side (EOF).
    fn read_once(stream: &mut UnixStream) -> Vec<u8> {
        let mut scratch = [0u8; 8];
        let received = stream.read(&mut scratch).unwrap();
        scratch[..received].to_vec()
    }

    #[test]
    fn forward_unix() {
        // Local pair: `london` is the test's endpoint, `folkestone` is handed to the forwarder.
        let (mut london, folkestone) = UnixStream::pair().unwrap();
        // Remote pair: `coquelles` is handed to the forwarder, `paris` is the test's endpoint.
        let (coquelles, mut paris) = UnixStream::pair().unwrap();

        // Bridge the two pairs through the chunnel.
        let mut session = ForwarderSession::new(folkestone.into(), coquelles.into());

        // local -> remote: the bytes arrive intact and EOF is not reported yet.
        let hello = b"hello";
        london.write_all(hello).unwrap();
        assert!(!session.forward_from_local().unwrap());
        assert_eq!(read_once(&mut paris), hello);

        // Half-closing London must be relayed as an EOF seen by Paris.
        london.shutdown(Shutdown::Write).unwrap();
        assert!(session.forward_from_local().unwrap());
        assert!(read_once(&mut paris).is_empty());

        // Only one direction is closed, so the session is still alive.
        assert!(!session.is_shut_down());

        // remote -> local: the same checks in the opposite direction.
        let bonjour = b"bonjour";
        paris.write_all(bonjour).unwrap();
        assert!(!session.forward_from_remote().unwrap());
        assert_eq!(read_once(&mut london), bonjour);

        // Half-closing Paris must be relayed as an EOF seen by London.
        paris.shutdown(Shutdown::Write).unwrap();
        assert!(session.forward_from_remote().unwrap());
        assert!(read_once(&mut london).is_empty());

        // With both directions closed, the session reports itself as shut down.
        assert!(session.is_shut_down());
    }
}
171