1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Barrier;
use std::thread;
use std::time::Duration;

use base::descriptor::FromRawDescriptor;
use base::descriptor::SafeDescriptor;
use base::deserialize_with_descriptors;
use base::Event;
use base::EventToken;
use base::EventWaitResult;
use base::ReadNotifier;
use base::RecvTube;
use base::SendTube;
use base::SerializeDescriptors;
use base::Tube;
use base::WaitContext;
use serde::Deserialize;
use serde::Serialize;
24
25 #[derive(Serialize, Deserialize)]
26 struct DataStruct {
27 x: u32,
28 }
29
30 #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
31 enum Token {
32 ReceivedData,
33 }
34
/// Magic tag stamped into every message sent by the first producer.
///
/// The consumer uses the tag to attribute each message to its producer. Any value other than
/// the two producer IDs indicates a corrupted message.
const PRODUCER_ID_1: u32 = 801_279_273;
/// Magic tag stamped into every message sent by the second producer (see `PRODUCER_ID_1`).
const PRODUCER_ID_2: u32 = 345_234_861;
38
39 #[track_caller]
test_event_pair(send: Event, recv: Event)40 fn test_event_pair(send: Event, recv: Event) {
41 send.signal().unwrap();
42 recv.wait_timeout(Duration::from_secs(1)).unwrap();
43 }
44
/// Round-trips a plain string (no descriptors attached) through a Tube pair.
#[test]
fn send_recv_no_fd() {
    let (tx, rx) = Tube::pair().unwrap();

    let expected = "hello world";
    tx.send(&expected).unwrap();

    assert_eq!(expected, rx.recv::<String>().unwrap());
}
55
/// Round-trips a struct holding both plain data and an `Event`. Checks that the plain field
/// survives and that the received event is connected to the one that was sent.
#[test]
fn send_recv_one_fd() {
    #[derive(Serialize, Deserialize)]
    struct EventStruct {
        x: u32,
        b: Event,
    }

    let (sender, receiver) = Tube::pair().unwrap();

    let original = EventStruct {
        x: 100,
        b: Event::new().unwrap(),
    };
    sender.send(&original).unwrap();
    let copy = receiver.recv::<EventStruct>().unwrap();

    assert_eq!(original.x, copy.x);
    test_event_pair(original.b, copy.b);
}
77
/// Sends a bare `Event` through a Tube. Checks that signalling the received copy wakes a
/// waiter on the original.
#[test]
fn send_recv_event() {
    let (req, res) = Tube::pair().unwrap();
    let e1 = Event::new().unwrap();
    res.send(&e1).unwrap();

    let recv_event: Event = req.recv().unwrap();
    recv_event.signal().unwrap();
    // Use a bounded wait: an unbounded `wait()` would hang the test forever instead of failing
    // it if the signal never reached `e1`. A timeout comes back as `Ok(TimedOut)`, so check
    // the result explicitly.
    let result = e1.wait_timeout(Duration::from_secs(1)).unwrap();
    assert!(
        matches!(result, EventWaitResult::Signaled),
        "original event was not signalled within the timeout"
    );
}
88
89 /// Send messages to a Tube with the given identifier (see `consume_messages`; we use this to
90 /// track different message producers).
91 #[track_caller]
produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube92 fn produce_messages(tube: SendTube, data: u32, barrier: Arc<Barrier>) -> SendTube {
93 let data = DataStruct { x: data };
94 barrier.wait();
95 for _ in 0..100 {
96 tube.send(&data).unwrap();
97 }
98 tube
99 }
100
101 /// Consumes the given number of messages from a Tube, returning the number messages read with
102 /// each producer ID.
103 #[track_caller]
consume_messages( tube: RecvTube, count: usize, barrier: Arc<Barrier>, ) -> (RecvTube, usize, usize)104 fn consume_messages(
105 tube: RecvTube,
106 count: usize,
107 barrier: Arc<Barrier>,
108 ) -> (RecvTube, usize, usize) {
109 barrier.wait();
110
111 let mut id1_count = 0usize;
112 let mut id2_count = 0usize;
113
114 for _ in 0..count {
115 let msg = tube.recv::<DataStruct>().unwrap();
116 match msg.x {
117 PRODUCER_ID_1 => id1_count += 1,
118 PRODUCER_ID_2 => id2_count += 1,
119 _ => panic!(
120 "want message with ID {} or {}; got message w/ ID {}.",
121 PRODUCER_ID_1, PRODUCER_ID_2, msg.x
122 ),
123 }
124 }
125 (tube, id1_count, id2_count)
126 }
127
/// Serializes one end of a Tube pair into JSON plus out-of-band descriptors and rebuilds a Tube
/// from them. Checks that a message sent on the rebuilt Tube arrives at the peer.
#[test]
fn test_serialize_tube_pair() {
    let (tube_send, tube_recv) = Tube::pair().unwrap();

    // `SerializeDescriptors` only borrows `tube_send`, so the raw descriptors returned by
    // `into_descriptors` below are presumably `tube_send`'s own, not duplicates. They are then
    // wrapped in owning `SafeDescriptor`s and end up inside `tube_deserialized`.
    // `tube_send` must therefore never be dropped. Otherwise the same descriptors would be
    // closed twice, and the second close can hit an unrelated descriptor that a concurrently
    // running test has reopened under the same number.
    // NOTE(review): if Tube serialization duplicates descriptors on some platform, this only
    // leaks one handle for the lifetime of the test process.
    let tube_send = std::mem::ManuallyDrop::new(tube_send);

    // Serialize the Tube
    let msg_serialize = SerializeDescriptors::new(&*tube_send);
    let serialized = serde_json::to_vec(&msg_serialize).unwrap();
    let msg_descriptors = msg_serialize.into_descriptors();

    // Deserialize the Tube
    let msg_descriptors_safe = msg_descriptors.into_iter().map(|v| {
        // SAFETY: `v` is a live descriptor belonging to `tube_send`, which is wrapped in
        // `ManuallyDrop` and never dropped, so this `SafeDescriptor` becomes its only owner.
        unsafe { SafeDescriptor::from_raw_descriptor(v) }
    });
    let tube_deserialized: Tube =
        deserialize_with_descriptors(|| serde_json::from_slice(&serialized), msg_descriptors_safe)
            .unwrap();

    // Send a message through deserialized Tube
    tube_deserialized.send(&"hi".to_string()).unwrap();

    // Wait for the message to arrive
    let wait_ctx: WaitContext<Token> =
        WaitContext::build_with(&[(tube_recv.get_read_notifier(), Token::ReceivedData)]).unwrap();
    let events = wait_ctx.wait_timeout(Duration::from_secs(10)).unwrap();
    let tokens: Vec<Token> = events
        .iter()
        .filter(|e| e.is_readable)
        .map(|e| e.token)
        .collect();
    assert_eq!(tokens, vec![Token::ReceivedData]);

    assert_eq!(tube_recv.recv::<String>().unwrap(), "hi");
}
161
/// Two producers share one `SendTube` (via `try_clone`) and each send 100 tagged messages. A
/// single consumer must receive all 200 messages, split evenly by producer ID.
#[test]
fn send_recv_mpsc() {
    let (producer_a, consumer) = Tube::directional_pair().unwrap();
    let producer_b = producer_a.try_clone().unwrap();

    // Both producers and the consumer rendezvous here before any messages are sent.
    let start = Arc::new(Barrier::new(3));
    let start_a = Arc::clone(&start);
    let start_b = Arc::clone(&start);

    let handle_a = thread::spawn(move || produce_messages(producer_a, PRODUCER_ID_1, start_a));
    let handle_b = thread::spawn(move || produce_messages(producer_b, PRODUCER_ID_2, start_b));

    let (_tube, from_a, from_b) = consume_messages(consumer, 200, start);
    assert_eq!(from_a, 100);
    assert_eq!(from_b, 100);

    handle_a.join().unwrap();
    handle_b.join().unwrap();
}
180
/// Sends a `HashMap` whose values are `Event`s. Checks that every key arrives and that each
/// received event is connected to the one sent under the same key.
#[test]
fn send_recv_hash_map() {
    let (sender, receiver) = Tube::pair().unwrap();

    let mut sent = HashMap::new();
    for color in ["Red", "White", "Blue", "Orange", "Green"] {
        sent.insert(color.to_owned(), Event::new().unwrap());
    }
    sender.send(&sent).unwrap();
    let mut received: HashMap<String, Event> = receiver.recv().unwrap();

    // HashMap iteration order is unspecified, so compare the key sets after sorting them.
    let mut sent_keys = sent.keys().collect::<Vec<_>>();
    let mut received_keys = received.keys().collect::<Vec<_>>();
    sent_keys.sort_unstable();
    received_keys.sort_unstable();
    assert_eq!(sent_keys, received_keys);

    for (color, original) in sent {
        let copy = received.remove(&color).unwrap();
        test_event_pair(original, copy);
    }
}
205