xref: /aosp_15_r20/external/crosvm/base/tests/tube.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::HashMap;
6 use std::sync::Arc;
7 use std::sync::Barrier;
8 use std::thread;
9 use std::time::Duration;
10 
11 use base::descriptor::FromRawDescriptor;
12 use base::descriptor::SafeDescriptor;
13 use base::deserialize_with_descriptors;
14 use base::Event;
15 use base::EventToken;
16 use base::ReadNotifier;
17 use base::RecvTube;
18 use base::SendTube;
19 use base::SerializeDescriptors;
20 use base::Tube;
21 use base::WaitContext;
22 use serde::Deserialize;
23 use serde::Serialize;
24 
25 #[derive(Serialize, Deserialize)]
26 struct DataStruct {
27     x: u32,
28 }
29 
30 #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
31 enum Token {
32     ReceivedData,
33 }
34 
// Magic values identifying which producer sent a message. Any other value seen by the
// consumer means the message was corrupted.
const PRODUCER_ID_1: u32 = 801_279_273;
const PRODUCER_ID_2: u32 = 345_234_861;
38 
39 #[track_caller]
test_event_pair(send: Event, recv: Event)40 fn test_event_pair(send: Event, recv: Event) {
41     send.signal().unwrap();
42     recv.wait_timeout(Duration::from_secs(1)).unwrap();
43 }
44 
/// A plain string with no attached descriptors round-trips through a `Tube` pair.
#[test]
fn send_recv_no_fd() {
    let (sender, receiver) = Tube::pair().unwrap();

    let expected = "hello world";
    sender.send(&expected).unwrap();

    let received = receiver.recv::<String>().unwrap();
    assert_eq!(expected, received);
}
55 
/// A struct holding both plain data and an `Event` round-trips through a `Tube`. The received
/// `Event` must be connected to the one that was sent.
#[test]
fn send_recv_one_fd() {
    #[derive(Serialize, Deserialize)]
    struct EventStruct {
        x: u32,
        b: Event,
    }

    let (sender, receiver) = Tube::pair().unwrap();

    let sent = EventStruct {
        x: 100,
        b: Event::new().unwrap(),
    };
    sender.send(&sent).unwrap();
    let received = receiver.recv::<EventStruct>().unwrap();

    assert_eq!(sent.x, received.x);
    test_event_pair(sent.b, received.b);
}
77 
/// An `Event` sent over a `Tube` and signaled on the receiving side wakes the original.
#[test]
fn send_recv_event() {
    let (req, res) = Tube::pair().unwrap();
    let original = Event::new().unwrap();
    res.send(&original).unwrap();

    let transferred = req.recv::<Event>().unwrap();
    transferred.signal().unwrap();
    original.wait().unwrap();
}
88 
89 /// Send messages to a Tube with the given identifier (see `consume_messages`; we use this to
90 /// track different message producers).
91 #[track_caller]
produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube92 fn produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube {
93     let data = DataStruct { x: data };
94     barrier.wait();
95     for _ in 0..100 {
96         tube.send(&data).unwrap();
97     }
98     tube
99 }
100 
101 /// Consumes the given number of messages from a Tube, returning the number messages read with
102 /// each producer ID.
103 #[track_caller]
consume_messages( tube: RecvTube, count: usize, barrier: Arc<Barrier>, ) -> (RecvTube, usize, usize)104 fn consume_messages(
105     tube: RecvTube,
106     count: usize,
107     barrier: Arc<Barrier>,
108 ) -> (RecvTube, usize, usize) {
109     barrier.wait();
110 
111     let mut id1_count = 0usize;
112     let mut id2_count = 0usize;
113 
114     for _ in 0..count {
115         let msg = tube.recv::<DataStruct>().unwrap();
116         match msg.x {
117             PRODUCER_ID_1 => id1_count += 1,
118             PRODUCER_ID_2 => id2_count += 1,
119             _ => panic!(
120                 "want message with ID {} or {}; got message w/ ID {}.",
121                 PRODUCER_ID_1, PRODUCER_ID_2, msg.x
122             ),
123         }
124     }
125     (tube, id1_count, id2_count)
126 }
127 
/// Checks that a `Tube` survives a serialize/deserialize round trip.
///
/// One end of a pair is serialized to JSON, with its descriptors collected out of band. It is
/// then rebuilt from that JSON and those descriptors. A message sent on the rebuilt end must
/// wake a `WaitContext` on the other end and arrive intact.
#[test]
fn test_serialize_tube_pair() {
    let (tube_send, tube_recv) = Tube::pair().unwrap();

    // Serialize the Tube
    let msg_serialize = SerializeDescriptors::new(&tube_send);
    let serialized = serde_json::to_vec(&msg_serialize).unwrap();
    let msg_descriptors = msg_serialize.into_descriptors();

    // Deserialize the Tube
    let msg_descriptors_safe = msg_descriptors.into_iter().map(|v| {
        // SAFETY: `v` was just extracted from `tube_send`, which is still alive, so it is valid.
        // Ownership moves to the deserialized Tube; `tube_send` is forgotten below so the
        // descriptor ends up with a single owner.
        unsafe { SafeDescriptor::from_raw_descriptor(v) }
    });
    let tube_deserialized: Tube =
        deserialize_with_descriptors(|| serde_json::from_slice(&serialized), msg_descriptors_safe)
            .unwrap();

    // The raw descriptors were taken straight from `tube_send` without duplication. Dropping
    // both `tube_send` and `tube_deserialized` would close them twice. Because tests run in
    // parallel, the second close could hit a descriptor number another test has just reused.
    std::mem::forget(tube_send);

    // Send a message through deserialized Tube
    tube_deserialized.send(&"hi".to_string()).unwrap();

    // Wait for the message to arrive
    let wait_ctx: WaitContext<Token> =
        WaitContext::build_with(&[(tube_recv.get_read_notifier(), Token::ReceivedData)]).unwrap();
    let events = wait_ctx.wait_timeout(Duration::from_secs(10)).unwrap();
    let tokens: Vec<Token> = events
        .iter()
        .filter(|e| e.is_readable)
        .map(|e| e.token)
        .collect();
    assert_eq!(tokens, vec![Token::ReceivedData]);

    assert_eq!(tube_recv.recv::<String>().unwrap(), "hi");
}
161 
/// Two producers sharing one directional `SendTube` (the second obtained via `try_clone`)
/// deliver all 200 messages to a single consumer, 100 from each producer.
#[test]
fn send_recv_mpsc() {
    let (p1, consumer) = Tube::directional_pair().unwrap();
    let p2 = p1.try_clone().unwrap();

    // One three-party barrier: both producer threads and the consumer (this thread) wait on it,
    // so nobody starts until all three are ready.
    let start = Arc::new(Barrier::new(3));
    let start_p1 = Arc::clone(&start);
    let start_p2 = Arc::clone(&start);

    let p1_thread = thread::spawn(move || produce_messages(p1, PRODUCER_ID_1, start_p1));
    let p2_thread = thread::spawn(move || produce_messages(p2, PRODUCER_ID_2, start_p2));

    let (_tube, id1_count, id2_count) = consume_messages(consumer, 200, start);
    assert_eq!(id1_count, 100);
    assert_eq!(id2_count, 100);

    p1_thread.join().unwrap();
    p2_thread.join().unwrap();
}
180 
/// A `HashMap` of named `Event`s round-trips through a `Tube`. The received map has the same
/// keys, and each received `Event` is connected to the one sent under the same key.
#[test]
fn send_recv_hash_map() {
    /// Returns the map's keys in sorted order so maps can be compared independent of hashing.
    fn sorted_keys(map: &HashMap<String, Event>) -> Vec<&String> {
        let mut keys: Vec<&String> = map.keys().collect();
        keys.sort();
        keys
    }

    let (s1, s2) = Tube::pair().unwrap();

    let mut sent = HashMap::new();
    for name in ["Red", "White", "Blue", "Orange", "Green"] {
        sent.insert(name.to_owned(), Event::new().unwrap());
    }
    s1.send(&sent).unwrap();
    let mut received: HashMap<String, Event> = s2.recv().unwrap();

    assert_eq!(sorted_keys(&sent), sorted_keys(&received));

    for (name, sent_event) in sent {
        let received_event = received.remove(&name).unwrap();
        test_event_pair(sent_event, received_event);
    }
}
205