xref: /aosp_15_r20/tools/netsim/rust/libslirp-rs/tests/integration_udp.rs (revision cf78ab8cffb8fc9207af348f23af247fb04370a6)
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use bytes::Bytes;
16 use etherparse::EtherType;
17 use etherparse::LinkHeader::Ethernet2;
18 use etherparse::{NetHeaders, PacketBuilder, PacketHeaders, PayloadSlice, TransportHeader};
19 use libslirp_rs::libslirp::LibSlirp;
20 use libslirp_rs::libslirp_config::SlirpConfig;
21 use std::fs;
22 use std::io;
23 use std::net::{SocketAddr, UdpSocket};
24 use std::sync::mpsc;
25 use std::thread;
26 use std::time::Duration;
27 
/// Datagram injected through libslirp; the echo server expects exactly these bytes.
const PAYLOAD: &[u8; 23] = b"Hello, UDP echo server!";
/// Reply the echo server sends back; the test expects it to come out of libslirp.
const PAYLOAD_PONG: &[u8; 23] = b"Hello, UDP echo client!";
30 
31 /// Test UDP packets sent through libslirp
32 #[cfg(not(windows))] // TOOD: remove once test is working on windows.
33 #[test]
udp_echo()34 fn udp_echo() {
35     let config = SlirpConfig { ..Default::default() };
36 
37     let before_fd_count = count_open_fds().unwrap();
38 
39     let (tx, rx) = mpsc::channel::<Bytes>();
40     let slirp = LibSlirp::new(config, tx, None);
41 
42     // Start up an IPV4 UDP echo server
43     let server_addr = one_shot_udp_echo_server().unwrap();
44 
45     println!("server addr {:?}", server_addr);
46     let server_ip = match server_addr {
47         SocketAddr::V4(addr) => addr.ip().to_owned(),
48         _ => panic!("Unsupported address type"),
49     };
50     // Source address
51     let source_ip = server_ip.clone();
52 
53     // Source and destination ports
54     let source_port: u16 = 20000;
55     let destination_port = server_addr.port();
56 
57     // Build the UDP packet
58     // with abitrary source and destination mac addrs
59     // We use server address 0.0.0.0 to avoid ARP packets
60     let builder = PacketBuilder::ethernet2([1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12])
61         .ipv4(source_ip.octets(), server_ip.octets(), 20)
62         .udp(source_port, destination_port);
63 
64     // Get some memory to store the result
65     let mut result = Vec::<u8>::with_capacity(builder.size(PAYLOAD.len()));
66 
67     // Serialize header and payload
68     builder.write(&mut result, PAYLOAD).unwrap();
69 
70     let headers = PacketHeaders::from_ethernet_slice(&result).unwrap();
71     if let Some(Ethernet2(ether_header)) = headers.link {
72         assert_eq!(ether_header.ether_type, EtherType::IPV4);
73     } else {
74         panic!("expected ethernet2 header");
75     }
76 
77     assert!(headers.net.is_some());
78     assert!(headers.transport.is_some());
79 
80     // Send to oneshot_udp_echo_server (via libslirp)
81     slirp.input(Bytes::from(result));
82 
83     // Read from oneshot_udp_echo server (via libslirp)
84     // No ARP packets will be seen
85 
86     // Try to receive a packet before end_time
87     match rx.recv_timeout(Duration::from_secs(2)) {
88         Ok(packet) => {
89             let headers = PacketHeaders::from_ethernet_slice(&packet).unwrap();
90 
91             if let Some(Ethernet2(ref ether_header)) = headers.link {
92                 assert_eq!(ether_header.ether_type, EtherType::IPV4);
93             } else {
94                 panic!("expected ethernet2 header");
95             }
96 
97             if let Some(NetHeaders::Ipv4(ipv4_header, _)) = headers.net {
98                 assert_eq!(ipv4_header.source, [127, 0, 0, 1]);
99                 assert_eq!(ipv4_header.destination, [0, 0, 0, 0]);
100             } else {
101                 panic!("expected IpV4 header, got {:?}", headers.net);
102             }
103 
104             if let Some(TransportHeader::Udp(udp_header)) = headers.transport {
105                 assert_eq!(udp_header.source_port, destination_port);
106                 assert_eq!(udp_header.destination_port, source_port);
107             } else {
108                 panic!("expected Udp header");
109             }
110 
111             if let PayloadSlice::Udp(payload) = headers.payload {
112                 assert_eq!(payload, PAYLOAD_PONG);
113             } else {
114                 panic!("expected Udp payload");
115             }
116         }
117         Err(mpsc::RecvTimeoutError::Timeout) => {
118             assert!(false, "Timeout waiting for udp packet");
119         }
120         Err(e) => {
121             panic!("Failed to receive data in main thread: {}", e);
122         }
123     }
124 
125     // validate data packet
126 
127     slirp.shutdown();
128     assert_eq!(
129         rx.recv_timeout(Duration::from_millis(5)),
130         Err(mpsc::RecvTimeoutError::Disconnected)
131     );
132 
133     let after_fd_count = count_open_fds().unwrap();
134     assert_eq!(before_fd_count, after_fd_count);
135 }
136 
one_shot_udp_echo_server() -> std::io::Result<SocketAddr>137 fn one_shot_udp_echo_server() -> std::io::Result<SocketAddr> {
138     let socket = UdpSocket::bind("0.0.0.0:0")?;
139     let addr = socket.local_addr()?;
140     thread::spawn(move || {
141         let mut buf = [0u8; 1024];
142         let (len, addr) = socket.recv_from(&mut buf).unwrap();
143         let data = &buf[..len];
144         if data != PAYLOAD {
145             panic!("mistmatch payload");
146         }
147         println!("sending to addr {addr:?}");
148         let _ = socket.send_to(PAYLOAD_PONG, addr);
149     });
150     Ok(addr)
151 }
152 
/// Returns how many file descriptors this process currently has open.
///
/// Counts the entries listed in `/proc/self/fd`; entries that fail to read are
/// counted too. On Linux the directory handle opened for the listing typically
/// appears in it as well. Because this adds the same constant to every call,
/// before/after comparisons are unaffected.
#[cfg(target_os = "linux")]
fn count_open_fds() -> io::Result<usize> {
    let fd_dir = fs::read_dir("/proc/self/fd")?;
    let mut total = 0;
    for _entry in fd_dir {
        total += 1;
    }
    Ok(total)
}
158 
/// Fallback for platforms without `/proc/self/fd`.
///
/// Always reports zero open descriptors. As a result, any before/after fd-leak
/// comparison built on this function is effectively a no-op on these platforms.
#[cfg(not(target_os = "linux"))]
fn count_open_fds() -> io::Result<usize> {
    let unsupported: usize = 0;
    io::Result::Ok(unsupported)
}
163