1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/forwarder.rs
17
18 //! This module contains forwarding mechanism between stream sockets.
19
20 use std::fmt;
21 use std::io::{self, Read, Write};
22 use std::result;
23
24 use crate::stream::StreamSocket;
25
// Size, in bytes, of the buffer that `forward()` reads into, and therefore the upper bound on how
// much data a single `forward()` call moves from one stream to the other.
//
// This was picked arbitrarily. crosvm doesn't yet use VIRTIO_NET_F_MTU, so there's no reason to
// opt for massive 65535 byte frames.
const MAX_FRAME_SIZE: usize = 8192;
29
/// Errors that can be encountered by a ForwarderSession.
///
/// Every variant carries the `io::Error` reported by the stream operation that failed.
#[remain::sorted]
#[derive(Debug)]
pub enum ForwarderError {
    /// An io::Error was encountered while reading from a stream.
    ///
    /// Produced when the read from the source stream of a forwarding step fails.
    ReadFromStream(io::Error),
    /// An io::Error was encountered while shutting down writes on a stream.
    ///
    /// Produced when the source stream read zero bytes and the subsequent write-side shutdown
    /// of the destination stream fails.
    ShutDownStream(io::Error),
    /// An io::Error was encountered while writing to a stream.
    ///
    /// Produced when writing the bytes that were just read to the destination stream fails.
    WriteToStream(io::Error),
}
41
/// Module-private shorthand for a `Result` whose error type is `ForwarderError`.
type Result<T> = result::Result<T, ForwarderError>;
43
44 impl fmt::Display for ForwarderError {
45 #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result46 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47 use self::ForwarderError::*;
48
49 #[remain::sorted]
50 match self {
51 ReadFromStream(e) => write!(f, "failed to read from stream: {}", e),
52 ShutDownStream(e) => write!(f, "failed to shut down stream: {}", e),
53 WriteToStream(e) => write!(f, "failed to write to stream: {}", e),
54 }
55 }
56 }
57
/// A ForwarderSession owns stream sockets that it forwards traffic between.
///
/// Both sockets are moved into the session by `ForwarderSession::new` and can afterwards be
/// inspected through `local_stream` and `remote_stream`.
pub struct ForwarderSession {
    /// Local endpoint: read from by `forward_from_local`, written to by `forward_from_remote`.
    local: StreamSocket,
    /// Remote endpoint: read from by `forward_from_remote`, written to by `forward_from_local`.
    remote: StreamSocket,
}
63
forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool>64 fn forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool> {
65 let mut buf = [0u8; MAX_FRAME_SIZE];
66
67 let count = from_stream.read(&mut buf).map_err(ForwarderError::ReadFromStream)?;
68 if count == 0 {
69 to_stream.shut_down_write().map_err(ForwarderError::ShutDownStream)?;
70 return Ok(true);
71 }
72
73 to_stream.write_all(&buf[..count]).map_err(ForwarderError::WriteToStream)?;
74 Ok(false)
75 }
76
77 impl ForwarderSession {
78 /// Creates a forwarder session from a local and remote stream socket.
new(local: StreamSocket, remote: StreamSocket) -> Self79 pub fn new(local: StreamSocket, remote: StreamSocket) -> Self {
80 ForwarderSession { local, remote }
81 }
82
83 /// Forwards traffic from the local socket to the remote socket.
84 /// Returns true if the local socket has reached EOF and the
85 /// remote socket has been shut down for further writes.
forward_from_local(&mut self) -> Result<bool>86 pub fn forward_from_local(&mut self) -> Result<bool> {
87 forward(&mut self.local, &mut self.remote)
88 }
89
90 /// Forwards traffic from the remote socket to the local socket.
91 /// Returns true if the remote socket has reached EOF and the
92 /// local socket has been shut down for further writes.
forward_from_remote(&mut self) -> Result<bool>93 pub fn forward_from_remote(&mut self) -> Result<bool> {
94 forward(&mut self.remote, &mut self.local)
95 }
96
97 /// Returns a reference to the local stream socket.
local_stream(&self) -> &StreamSocket98 pub fn local_stream(&self) -> &StreamSocket {
99 &self.local
100 }
101
102 /// Returns a reference to the remote stream socket.
remote_stream(&self) -> &StreamSocket103 pub fn remote_stream(&self) -> &StreamSocket {
104 &self.remote
105 }
106
107 /// Returns true if both sockets are completely shut down and the session can be dropped.
is_shut_down(&self) -> bool108 pub fn is_shut_down(&self) -> bool {
109 self.local.is_shut_down() && self.remote.is_shut_down()
110 }
111 }
112
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::Shutdown;
    use std::os::unix::net::UnixStream;

    /// Performs a single read of at most 8 bytes from `stream` and returns the bytes received.
    /// An empty result means the read reported EOF.
    fn read_once(stream: &mut UnixStream) -> Vec<u8> {
        let mut buf = [0u8; 8];
        let len = stream.read(&mut buf).unwrap();
        buf[..len].to_vec()
    }

    #[test]
    fn forward_unix() {
        // London <-> Folkestone is the local socket pair, Coquelles <-> Paris the remote one.
        // The forwarder (the chunnel) joins Folkestone to Coquelles.
        let (mut london, folkestone) = UnixStream::pair().unwrap();
        let (coquelles, mut paris) = UnixStream::pair().unwrap();
        let mut forwarder = ForwarderSession::new(folkestone.into(), coquelles.into());

        // Local -> remote: data written in London must come out in Paris, and forwarding it
        // must not report EOF.
        london.write_all(b"hello").unwrap();
        let local_eof = forwarder.forward_from_local().unwrap();
        assert!(!local_eof);
        assert_eq!(read_once(&mut paris), b"hello");

        // Closing London's write half must be detected by the forwarder and propagated as a
        // shutdown, which Paris observes as EOF.
        london.shutdown(Shutdown::Write).unwrap();
        let local_eof = forwarder.forward_from_local().unwrap();
        assert!(local_eof);
        assert!(read_once(&mut paris).is_empty());

        // Only one direction is finished, so the session as a whole is not shut down yet.
        assert!(!forwarder.is_shut_down());

        // Remote -> local: data written in Paris must come out in London, and forwarding it
        // must not report EOF.
        paris.write_all(b"bonjour").unwrap();
        let remote_eof = forwarder.forward_from_remote().unwrap();
        assert!(!remote_eof);
        assert_eq!(read_once(&mut london), b"bonjour");

        // Closing Paris's write half must likewise be propagated, which London observes as EOF.
        paris.shutdown(Shutdown::Write).unwrap();
        let remote_eof = forwarder.forward_from_remote().unwrap();
        assert!(remote_eof);
        assert!(read_once(&mut london).is_empty());

        // Both directions are finished: the session now counts as shut down.
        assert!(forwarder.is_shut_down());
    }
}
171