1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use bytes::Bytes;
16 use etherparse::EtherType;
17 use etherparse::LinkHeader::Ethernet2;
18 use etherparse::{NetHeaders, PacketBuilder, PacketHeaders, PayloadSlice, TransportHeader};
19 use libslirp_rs::libslirp::LibSlirp;
20 use libslirp_rs::libslirp_config::SlirpConfig;
21 use std::fs;
22 use std::io;
23 use std::net::{SocketAddr, UdpSocket};
24 use std::sync::mpsc;
25 use std::thread;
26 use std::time::Duration;
27
/// Datagram the test writes into the frame it injects into libslirp; the echo
/// server checks that it receives exactly these bytes.
const PAYLOAD: &[u8; 23] = b"Hello, UDP echo server!";
/// Datagram the echo server sends back; the test expects it as the UDP payload
/// of the frame libslirp delivers on its output channel.
const PAYLOAD_PONG: &[u8; 23] = b"Hello, UDP echo client!";
30
/// Test UDP packets sent through libslirp.
///
/// Hand-builds an Ethernet II / IPv4 / UDP frame addressed to a one-shot echo
/// server, feeds it to libslirp with `LibSlirp::input`, and checks that the
/// echoed reply comes back on the output channel as a well-formed frame with
/// the expected addresses, ports and payload. Finally it shuts libslirp down,
/// checks that the output channel is disconnected, and compares open file
/// descriptor counts taken before and after the test (only meaningful on
/// Linux, the one target where `count_open_fds` actually counts).
#[cfg(not(windows))] // TODO: remove once test is working on windows.
#[test]
fn udp_echo() {
    let config = SlirpConfig::default();

    // Taken before libslirp or the echo server open any descriptors.
    let before_fd_count = count_open_fds().unwrap();

    let (tx, rx) = mpsc::channel::<Bytes>();
    let slirp = LibSlirp::new(config, tx, None);

    // Start up an IPv4 UDP echo server.
    let server_addr = one_shot_udp_echo_server().unwrap();
    println!("server addr {:?}", server_addr);

    let server_ip = match server_addr {
        SocketAddr::V4(addr) => *addr.ip(),
        _ => panic!("Unsupported address type"),
    };
    // The frame's source IP is the same address the server is bound to.
    // Ipv4Addr is Copy, so no clone is needed.
    let source_ip = server_ip;

    // Source and destination ports.
    let source_port: u16 = 20000;
    let destination_port = server_addr.port();

    // Build the UDP frame with arbitrary source and destination MAC addresses.
    // The echo server binds to 0.0.0.0, so that is the destination IP; using
    // it is intended to avoid ARP packets.
    let builder = PacketBuilder::ethernet2([1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12])
        .ipv4(source_ip.octets(), server_ip.octets(), 20)
        .udp(source_port, destination_port);

    // Serialize headers and payload into a buffer sized for the whole frame.
    let mut frame = Vec::<u8>::with_capacity(builder.size(PAYLOAD.len()));
    builder.write(&mut frame, PAYLOAD).unwrap();

    // Sanity-check the outgoing frame before handing it to libslirp.
    let headers = PacketHeaders::from_ethernet_slice(&frame).unwrap();
    if let Some(Ethernet2(ether_header)) = headers.link {
        assert_eq!(ether_header.ether_type, EtherType::IPV4);
    } else {
        panic!("expected ethernet2 header");
    }
    assert!(headers.net.is_some());
    assert!(headers.transport.is_some());

    // Send to the one-shot echo server (via libslirp).
    slirp.input(Bytes::from(frame));

    // Wait for the echoed reply (via libslirp). No ARP packets are expected,
    // so the first frame received should be the reply itself.
    let packet = match rx.recv_timeout(Duration::from_secs(2)) {
        Ok(packet) => packet,
        Err(mpsc::RecvTimeoutError::Timeout) => panic!("Timeout waiting for udp packet"),
        Err(e) => panic!("Failed to receive data in main thread: {}", e),
    };

    // Validate the reply frame layer by layer.
    let headers = PacketHeaders::from_ethernet_slice(&packet).unwrap();

    if let Some(Ethernet2(ref ether_header)) = headers.link {
        assert_eq!(ether_header.ether_type, EtherType::IPV4);
    } else {
        panic!("expected ethernet2 header");
    }

    // The reply is expected to come from 127.0.0.1 and return to 0.0.0.0.
    if let Some(NetHeaders::Ipv4(ipv4_header, _)) = headers.net {
        assert_eq!(ipv4_header.source, [127, 0, 0, 1]);
        assert_eq!(ipv4_header.destination, [0, 0, 0, 0]);
    } else {
        panic!("expected IpV4 header, got {:?}", headers.net);
    }

    // Ports are swapped relative to the request.
    if let Some(TransportHeader::Udp(udp_header)) = headers.transport {
        assert_eq!(udp_header.source_port, destination_port);
        assert_eq!(udp_header.destination_port, source_port);
    } else {
        panic!("expected Udp header");
    }

    if let PayloadSlice::Udp(payload) = headers.payload {
        assert_eq!(payload, PAYLOAD_PONG);
    } else {
        panic!("expected Udp payload");
    }

    // After shutdown the test expects libslirp's sender to be gone, so the
    // channel reports Disconnected rather than timing out.
    slirp.shutdown();
    assert_eq!(
        rx.recv_timeout(Duration::from_millis(5)),
        Err(mpsc::RecvTimeoutError::Disconnected)
    );

    // NOTE(review): the echo server thread drops its socket only after its
    // send_to returns, so this comparison assumes that thread has finished by
    // now. Confirm this is not a source of flakiness.
    let after_fd_count = count_open_fds().unwrap();
    assert_eq!(before_fd_count, after_fd_count);
}
136
one_shot_udp_echo_server() -> std::io::Result<SocketAddr>137 fn one_shot_udp_echo_server() -> std::io::Result<SocketAddr> {
138 let socket = UdpSocket::bind("0.0.0.0:0")?;
139 let addr = socket.local_addr()?;
140 thread::spawn(move || {
141 let mut buf = [0u8; 1024];
142 let (len, addr) = socket.recv_from(&mut buf).unwrap();
143 let data = &buf[..len];
144 if data != PAYLOAD {
145 panic!("mistmatch payload");
146 }
147 println!("sending to addr {addr:?}");
148 let _ = socket.send_to(PAYLOAD_PONG, addr);
149 });
150 Ok(addr)
151 }
152
/// Counts the entries listed under `/proc/self/fd`, i.e. this process's open
/// file descriptors as reported by procfs.
///
/// Every item yielded by the directory iterator is counted, including any
/// per-entry errors.
///
/// # Errors
///
/// Returns the error from `fs::read_dir` if the directory cannot be opened.
#[cfg(target_os = "linux")]
fn count_open_fds() -> io::Result<usize> {
    fs::read_dir("/proc/self/fd").map(Iterator::count)
}
158
/// Fallback used on every target other than Linux: descriptor counting is not
/// implemented here, so it always reports zero. Any before/after comparison
/// therefore trivially matches.
#[cfg(not(target_os = "linux"))]
fn count_open_fds() -> io::Result<usize> {
    const UNTRACKED: usize = 0;
    Ok(UNTRACKED)
}
163