1 use std::ffi::CString;
2 use std::fs::File;
3 use std::io::Result;
4 use std::os::unix::io::{AsRawFd, FromRawFd};
5 use std::os::unix::net::UnixStream;
6 use std::path::Path;
7 use std::sync::{Arc, Barrier, Mutex};
8 use std::thread;
9 
10 use vhost::vhost_user::message::{
11     VhostUserConfigFlags, VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures,
12 };
13 use vhost::vhost_user::{Listener, Master, Slave, VhostUserMaster};
14 use vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
15 use vhost_user_backend::{VhostUserBackendMut, VhostUserDaemon, VringRwLock};
16 use vm_memory::{
17     FileOffset, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, GuestMemoryMmap,
18 };
19 use vmm_sys_util::epoll::EventSet;
20 use vmm_sys_util::eventfd::EventFd;
21 
/// In-process backend stub that the tests below hand to `VhostUserDaemon`.
///
/// It records a few values it is given so the mock's trait implementation has
/// somewhere to store them.
struct MockVhostBackend {
    /// Number of times `handle_event` has been invoked.
    events: u64,
    /// Most recent value passed to `set_event_idx`.
    event_idx: bool,
    /// Most recent feature bits passed to `acked_features`.
    acked_features: u64,
}

impl MockVhostBackend {
    /// Builds a backend with a zero event count, event-idx disabled and no
    /// acknowledged features.
    fn new() -> Self {
        Self {
            acked_features: 0,
            event_idx: false,
            events: 0,
        }
    }
}
37 
38 impl VhostUserBackendMut<VringRwLock, ()> for MockVhostBackend {
num_queues(&self) -> usize39     fn num_queues(&self) -> usize {
40         2
41     }
42 
max_queue_size(&self) -> usize43     fn max_queue_size(&self) -> usize {
44         256
45     }
46 
features(&self) -> u6447     fn features(&self) -> u64 {
48         0xffff_ffff_ffff_ffff
49     }
50 
acked_features(&mut self, features: u64)51     fn acked_features(&mut self, features: u64) {
52         self.acked_features = features;
53     }
54 
protocol_features(&self) -> VhostUserProtocolFeatures55     fn protocol_features(&self) -> VhostUserProtocolFeatures {
56         VhostUserProtocolFeatures::all()
57     }
58 
set_event_idx(&mut self, enabled: bool)59     fn set_event_idx(&mut self, enabled: bool) {
60         self.event_idx = enabled;
61     }
62 
get_config(&self, offset: u32, size: u32) -> Vec<u8>63     fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
64         assert_eq!(offset, 0x200);
65         assert_eq!(size, 8);
66 
67         vec![0xa5u8; 8]
68     }
69 
set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()>70     fn set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()> {
71         assert_eq!(offset, 0x200);
72         assert_eq!(buf, &[0xa5u8; 8]);
73 
74         Ok(())
75     }
76 
update_memory(&mut self, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()>77     fn update_memory(&mut self, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> {
78         let mem = atomic_mem.memory();
79         let region = mem.find_region(GuestAddress(0x100000)).unwrap();
80         assert_eq!(region.size(), 0x100000);
81         Ok(())
82     }
83 
set_slave_req_fd(&mut self, _slave: Slave)84     fn set_slave_req_fd(&mut self, _slave: Slave) {}
85 
queues_per_thread(&self) -> Vec<u64>86     fn queues_per_thread(&self) -> Vec<u64> {
87         vec![1, 1]
88     }
89 
exit_event(&self, _thread_index: usize) -> Option<EventFd>90     fn exit_event(&self, _thread_index: usize) -> Option<EventFd> {
91         let event_fd = EventFd::new(0).unwrap();
92 
93         Some(event_fd)
94     }
95 
handle_event( &mut self, _device_event: u16, _evset: EventSet, _vrings: &[VringRwLock], _thread_id: usize, ) -> Result<bool>96     fn handle_event(
97         &mut self,
98         _device_event: u16,
99         _evset: EventSet,
100         _vrings: &[VringRwLock],
101         _thread_id: usize,
102     ) -> Result<bool> {
103         self.events += 1;
104 
105         Ok(false)
106     }
107 }
108 
setup_master(path: &Path, barrier: Arc<Barrier>) -> Master109 fn setup_master(path: &Path, barrier: Arc<Barrier>) -> Master {
110     barrier.wait();
111     let mut master = Master::connect(path, 1).unwrap();
112     master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
113     // Wait before issue service requests.
114     barrier.wait();
115 
116     let features = master.get_features().unwrap();
117     let proto = master.get_protocol_features().unwrap();
118     master.set_features(features).unwrap();
119     master.set_protocol_features(proto).unwrap();
120     assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
121 
122     master
123 }
124 
vhost_user_client(path: &Path, barrier: Arc<Barrier>)125 fn vhost_user_client(path: &Path, barrier: Arc<Barrier>) {
126     barrier.wait();
127     let mut master = Master::connect(path, 1).unwrap();
128     master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
129     // Wait before issue service requests.
130     barrier.wait();
131 
132     let features = master.get_features().unwrap();
133     let proto = master.get_protocol_features().unwrap();
134     master.set_features(features).unwrap();
135     master.set_protocol_features(proto).unwrap();
136     assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
137 
138     let queue_num = master.get_queue_num().unwrap();
139     assert_eq!(queue_num, 2);
140 
141     master.set_owner().unwrap();
142     //master.set_owner().unwrap_err();
143     master.reset_owner().unwrap();
144     master.reset_owner().unwrap();
145     master.set_owner().unwrap();
146 
147     master.set_features(features).unwrap();
148     master.set_protocol_features(proto).unwrap();
149     assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
150 
151     let memfd = nix::sys::memfd::memfd_create(
152         &CString::new("test").unwrap(),
153         nix::sys::memfd::MemFdCreateFlag::empty(),
154     )
155     .unwrap();
156     // SAFETY: Safe because we panic before if memfd is not valid.
157     let file = unsafe { File::from_raw_fd(memfd) };
158     file.set_len(0x100000).unwrap();
159     let file_offset = FileOffset::new(file, 0);
160     let mem = GuestMemoryMmap::<()>::from_ranges_with_files(&[(
161         GuestAddress(0x100000),
162         0x100000,
163         Some(file_offset),
164     )])
165     .unwrap();
166     let addr = mem.get_host_address(GuestAddress(0x100000)).unwrap() as u64;
167     let reg = mem.find_region(GuestAddress(0x100000)).unwrap();
168     let fd = reg.file_offset().unwrap();
169     let regions = [VhostUserMemoryRegionInfo::new(
170         0x100000,
171         0x100000,
172         addr,
173         0,
174         fd.file().as_raw_fd(),
175     )];
176     master.set_mem_table(&regions).unwrap();
177 
178     master.set_vring_num(0, 256).unwrap();
179 
180     let config = VringConfigData {
181         queue_max_size: 256,
182         queue_size: 256,
183         flags: 0,
184         desc_table_addr: addr,
185         used_ring_addr: addr + 0x10000,
186         avail_ring_addr: addr + 0x20000,
187         log_addr: None,
188     };
189     master.set_vring_addr(0, &config).unwrap();
190 
191     let eventfd = EventFd::new(0).unwrap();
192     master.set_vring_kick(0, &eventfd).unwrap();
193     master.set_vring_call(0, &eventfd).unwrap();
194     master.set_vring_err(0, &eventfd).unwrap();
195     master.set_vring_enable(0, true).unwrap();
196 
197     let buf = [0u8; 8];
198     let (_cfg, data) = master
199         .get_config(0x200, 8, VhostUserConfigFlags::empty(), &buf)
200         .unwrap();
201     assert_eq!(&data, &[0xa5u8; 8]);
202     master
203         .set_config(0x200, VhostUserConfigFlags::empty(), &data)
204         .unwrap();
205 
206     let (tx, _rx) = UnixStream::pair().unwrap();
207     master.set_slave_request_fd(&tx).unwrap();
208 
209     let state = master.get_vring_base(0).unwrap();
210     master.set_vring_base(0, state as u16).unwrap();
211 
212     assert_eq!(master.get_max_mem_slots().unwrap(), 32);
213     let region = VhostUserMemoryRegionInfo::new(0x800000, 0x100000, addr, 0, fd.file().as_raw_fd());
214     master.add_mem_region(&region).unwrap();
215     master.remove_mem_region(&region).unwrap();
216 }
217 
vhost_user_server(cb: fn(&Path, Arc<Barrier>))218 fn vhost_user_server(cb: fn(&Path, Arc<Barrier>)) {
219     let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new());
220     let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
221     let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
222 
223     let barrier = Arc::new(Barrier::new(2));
224     let tmpdir = tempfile::tempdir().unwrap();
225     let mut path = tmpdir.path().to_path_buf();
226     path.push("socket");
227 
228     let barrier2 = barrier.clone();
229     let path1 = path.clone();
230     let thread = thread::spawn(move || cb(&path1, barrier2));
231 
232     let listener = Listener::new(&path, false).unwrap();
233     barrier.wait();
234     daemon.start(listener).unwrap();
235     barrier.wait();
236 
237     // handle service requests from clients.
238     thread.join().unwrap();
239 }
240 
/// End-to-end run of the full request sequence in `vhost_user_client`.
#[test]
fn test_vhost_user_server() {
    let client: fn(&Path, Arc<Barrier>) = vhost_user_client;
    vhost_user_server(client);
}
245 
vhost_user_enable(path: &Path, barrier: Arc<Barrier>)246 fn vhost_user_enable(path: &Path, barrier: Arc<Barrier>) {
247     let master = setup_master(path, barrier);
248     master.set_owner().unwrap();
249     master.set_owner().unwrap_err();
250 }
251 
252 #[test]
test_vhost_user_enable()253 fn test_vhost_user_enable() {
254     vhost_user_server(vhost_user_enable);
255 }
256 
vhost_user_set_inflight(path: &Path, barrier: Arc<Barrier>)257 fn vhost_user_set_inflight(path: &Path, barrier: Arc<Barrier>) {
258     let mut master = setup_master(path, barrier);
259     let eventfd = EventFd::new(0).unwrap();
260     // No implementation for inflight_fd yet.
261     let inflight = VhostUserInflight {
262         mmap_size: 0x100000,
263         mmap_offset: 0,
264         num_queues: 1,
265         queue_size: 256,
266     };
267     master
268         .set_inflight_fd(&inflight, eventfd.as_raw_fd())
269         .unwrap_err();
270 }
271 
272 #[test]
test_vhost_user_set_inflight()273 fn test_vhost_user_set_inflight() {
274     vhost_user_server(vhost_user_set_inflight);
275 }
276 
vhost_user_get_inflight(path: &Path, barrier: Arc<Barrier>)277 fn vhost_user_get_inflight(path: &Path, barrier: Arc<Barrier>) {
278     let mut master = setup_master(path, barrier);
279     // No implementation for inflight_fd yet.
280     let inflight = VhostUserInflight {
281         mmap_size: 0x100000,
282         mmap_offset: 0,
283         num_queues: 1,
284         queue_size: 256,
285     };
286     assert!(master.get_inflight_fd(&inflight).is_err());
287 }
288 
289 #[test]
test_vhost_user_get_inflight()290 fn test_vhost_user_get_inflight() {
291     vhost_user_server(vhost_user_get_inflight);
292 }
293