xref: /aosp_15_r20/external/crosvm/third_party/vmm_vhost/src/backend_server.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 
7 use base::AsRawDescriptor;
8 use base::RawDescriptor;
9 use base::SafeDescriptor;
10 use zerocopy::AsBytes;
11 use zerocopy::FromBytes;
12 use zerocopy::Ref;
13 
14 use crate::into_single_file;
15 use crate::message::*;
16 use crate::BackendReq;
17 use crate::Connection;
18 use crate::Error;
19 use crate::FrontendReq;
20 use crate::Result;
21 
22 /// Trait for vhost-user backends.
23 ///
24 /// Each method corresponds to a vhost-user protocol method. See the specification for details.
25 #[allow(missing_docs)]
26 pub trait Backend {
set_owner(&mut self) -> Result<()>27     fn set_owner(&mut self) -> Result<()>;
reset_owner(&mut self) -> Result<()>28     fn reset_owner(&mut self) -> Result<()>;
get_features(&mut self) -> Result<u64>29     fn get_features(&mut self) -> Result<u64>;
set_features(&mut self, features: u64) -> Result<()>30     fn set_features(&mut self, features: u64) -> Result<()>;
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>31     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>32     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>33     fn set_vring_addr(
34         &mut self,
35         index: u32,
36         flags: VhostUserVringAddrFlags,
37         descriptor: u64,
38         used: u64,
39         available: u64,
40         log: u64,
41     ) -> Result<()>;
42     // TODO: b/331466964 - Argument type is wrong for packed queues.
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>43     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
44     // TODO: b/331466964 - Return type is wrong for packed queues.
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>45     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>46     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>47     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>48     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
49 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>50     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&mut self, features: u64) -> Result<()>51     fn set_protocol_features(&mut self, features: u64) -> Result<()>;
get_queue_num(&mut self) -> Result<u64>52     fn get_queue_num(&mut self) -> Result<u64>;
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>53     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>54     fn get_config(
55         &mut self,
56         offset: u32,
57         size: u32,
58         flags: VhostUserConfigFlags,
59     ) -> Result<Vec<u8>>;
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>60     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>)61     fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {}
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>62     fn get_inflight_fd(
63         &mut self,
64         inflight: &VhostUserInflight,
65     ) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>66     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&mut self) -> Result<u64>67     fn get_max_mem_slots(&mut self) -> Result<u64>;
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>68     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>69     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, fd: File, ) -> Result<Option<File>>70     fn set_device_state_fd(
71         &mut self,
72         transfer_direction: VhostUserTransferDirection,
73         migration_phase: VhostUserMigrationPhase,
74         fd: File,
75     ) -> Result<Option<File>>;
check_device_state(&mut self) -> Result<()>76     fn check_device_state(&mut self) -> Result<()>;
get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>77     fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
78 }
79 
80 impl<T> Backend for T
81 where
82     T: AsMut<dyn Backend>,
83 {
set_owner(&mut self) -> Result<()>84     fn set_owner(&mut self) -> Result<()> {
85         self.as_mut().set_owner()
86     }
87 
reset_owner(&mut self) -> Result<()>88     fn reset_owner(&mut self) -> Result<()> {
89         self.as_mut().reset_owner()
90     }
91 
get_features(&mut self) -> Result<u64>92     fn get_features(&mut self) -> Result<u64> {
93         self.as_mut().get_features()
94     }
95 
set_features(&mut self, features: u64) -> Result<()>96     fn set_features(&mut self, features: u64) -> Result<()> {
97         self.as_mut().set_features(features)
98     }
99 
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>100     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
101         self.as_mut().set_mem_table(ctx, files)
102     }
103 
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>104     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
105         self.as_mut().set_vring_num(index, num)
106     }
107 
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>108     fn set_vring_addr(
109         &mut self,
110         index: u32,
111         flags: VhostUserVringAddrFlags,
112         descriptor: u64,
113         used: u64,
114         available: u64,
115         log: u64,
116     ) -> Result<()> {
117         self.as_mut()
118             .set_vring_addr(index, flags, descriptor, used, available, log)
119     }
120 
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>121     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
122         self.as_mut().set_vring_base(index, base)
123     }
124 
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>125     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
126         self.as_mut().get_vring_base(index)
127     }
128 
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>129     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
130         self.as_mut().set_vring_kick(index, fd)
131     }
132 
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>133     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
134         self.as_mut().set_vring_call(index, fd)
135     }
136 
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>137     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
138         self.as_mut().set_vring_err(index, fd)
139     }
140 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>141     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
142         self.as_mut().get_protocol_features()
143     }
144 
set_protocol_features(&mut self, features: u64) -> Result<()>145     fn set_protocol_features(&mut self, features: u64) -> Result<()> {
146         self.as_mut().set_protocol_features(features)
147     }
148 
get_queue_num(&mut self) -> Result<u64>149     fn get_queue_num(&mut self) -> Result<u64> {
150         self.as_mut().get_queue_num()
151     }
152 
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>153     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
154         self.as_mut().set_vring_enable(index, enable)
155     }
156 
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>157     fn get_config(
158         &mut self,
159         offset: u32,
160         size: u32,
161         flags: VhostUserConfigFlags,
162     ) -> Result<Vec<u8>> {
163         self.as_mut().get_config(offset, size, flags)
164     }
165 
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>166     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
167         self.as_mut().set_config(offset, buf, flags)
168     }
169 
set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>)170     fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) {
171         self.as_mut().set_backend_req_fd(vu_req)
172     }
173 
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>174     fn get_inflight_fd(
175         &mut self,
176         inflight: &VhostUserInflight,
177     ) -> Result<(VhostUserInflight, File)> {
178         self.as_mut().get_inflight_fd(inflight)
179     }
180 
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>181     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> {
182         self.as_mut().set_inflight_fd(inflight, file)
183     }
184 
get_max_mem_slots(&mut self) -> Result<u64>185     fn get_max_mem_slots(&mut self) -> Result<u64> {
186         self.as_mut().get_max_mem_slots()
187     }
188 
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>189     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
190         self.as_mut().add_mem_region(region, fd)
191     }
192 
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>193     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
194         self.as_mut().remove_mem_region(region)
195     }
196 
set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, fd: File, ) -> Result<Option<File>>197     fn set_device_state_fd(
198         &mut self,
199         transfer_direction: VhostUserTransferDirection,
200         migration_phase: VhostUserMigrationPhase,
201         fd: File,
202     ) -> Result<Option<File>> {
203         self.as_mut()
204             .set_device_state_fd(transfer_direction, migration_phase, fd)
205     }
206 
check_device_state(&mut self) -> Result<()>207     fn check_device_state(&mut self) -> Result<()> {
208         self.as_mut().check_device_state()
209     }
210 
get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>211     fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> {
212         self.as_mut().get_shared_memory_regions()
213     }
214 }
215 
216 /// Handles requests from a vhost-user connection by dispatching them to [[Backend]] methods.
217 pub struct BackendServer<S: Backend> {
218     /// Underlying connection for communication.
219     connection: Connection<FrontendReq>,
220     // the vhost-user backend device object
221     backend: S,
222 
223     virtio_features: u64,
224     acked_virtio_features: u64,
225     protocol_features: VhostUserProtocolFeatures,
226     acked_protocol_features: u64,
227 
228     /// Sending ack for messages without payload.
229     reply_ack_enabled: bool,
230 }
231 
232 impl<S: Backend> AsRef<S> for BackendServer<S> {
as_ref(&self) -> &S233     fn as_ref(&self) -> &S {
234         &self.backend
235     }
236 }
237 
238 impl<S: Backend> BackendServer<S> {
new(connection: Connection<FrontendReq>, backend: S) -> Self239     pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self {
240         BackendServer {
241             connection,
242             backend,
243             virtio_features: 0,
244             acked_virtio_features: 0,
245             protocol_features: VhostUserProtocolFeatures::empty(),
246             acked_protocol_features: 0,
247             reply_ack_enabled: false,
248         }
249     }
250 
251     /// Receives and validates a vhost-user message header and optional files.
252     ///
253     /// Since the length of vhost-user messages are different among message types, regular
254     /// vhost-user messages are sent via an underlying communication channel in stream mode.
255     /// (e.g. `SOCK_STREAM` in UNIX)
256     /// So, the logic of receiving and handling a message consists of the following steps:
257     ///
258     /// 1. Receives a message header and optional attached file.
259     /// 2. Validates the message header.
260     /// 3. Check if optional payloads is expected.
261     /// 4. Wait for the optional payloads.
262     /// 5. Receives optional payloads.
263     /// 6. Processes the message.
264     ///
265     /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2),
266     /// [`BackendServer::needs_wait_for_payload()`] is (3), and
267     /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method
268     /// separately for multi-platform supports; [`BackendServer::recv_header()`] and
269     /// [`BackendServer::process_message()`] need to be separated because the way of waiting for
270     /// incoming messages differs between Unix and Windows so it's the caller's responsibility to
271     /// wait before [`BackendServer::process_message()`].
272     ///
273     /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
274     /// case, a message header and its body are sent together so the step (4) is skipped. We handle
275     /// this case in [`BackendServer::needs_wait_for_payload()`].
276     ///
277     /// The following pseudo code describes how a caller should process incoming vhost-user
278     /// messages:
279     /// ```ignore
280     /// loop {
281     ///   // block until a message header comes.
282     ///   // The actual code differs, depending on platforms.
283     ///   connection.wait_readable().unwrap();
284     ///
285     ///   // (1) and (2)
286     ///   let (hdr, files) = backend_server.recv_header();
287     ///
288     ///   // (3)
289     ///   if backend_server.needs_wait_for_payload(&hdr) {
290     ///     // (4) block until a payload comes if needed.
291     ///     connection.wait_readable().unwrap();
292     ///   }
293     ///
294     ///   // (5) and (6)
295     ///   backend_server.process_message(&hdr, &files).unwrap();
296     /// }
297     /// ```
recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)>298     pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> {
299         // The underlying communication channel is a Unix domain socket in
300         // stream mode, and recvmsg() is a little tricky here. To successfully
301         // receive attached file descriptors, we need to receive messages and
302         // corresponding attached file descriptors in this way:
303         // . recv messsage header and optional attached file
304         // . validate message header
305         // . recv optional message body and payload according size field in
306         //   message header
307         // . validate message body and optional payload
308         let (hdr, files) = match self.connection.recv_header() {
309             Ok((hdr, files)) => (hdr, files),
310             Err(Error::Disconnect) => {
311                 // If the client closed the connection before sending a header, this should be
312                 // handled as a legal exit.
313                 return Err(Error::ClientExit);
314             }
315             Err(e) => {
316                 return Err(e);
317             }
318         };
319 
320         self.check_attached_files(&hdr, &files)?;
321 
322         Ok((hdr, files))
323     }
324 
325     /// Returns whether the caller needs to wait for the incoming message before calling
326     /// [`BackendServer::process_message`].
327     ///
328     /// See [`BackendServer::recv_header`]'s doc comment for the usage.
needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool329     pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool {
330         // Since the vhost-user protocol uses stream mode, we need to wait until an additional
331         // payload is available if exists.
332         hdr.get_size() != 0
333     }
334 
335     /// Main entrance to request from the communication channel.
336     ///
337     /// Receive and handle one incoming request message from the frontend.
338     /// See [`BackendServer::recv_header`]'s doc comment for the usage.
339     ///
340     /// # Return:
341     /// * `Ok(())`: one request was successfully handled.
342     /// * `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual
343     ///   failure.
344     /// * `Err(Disconnect)`: the connection was closed unexpectedly.
345     /// * `Err(InvalidMessage)`: the vmm sent a illegal message.
346     /// * other errors: failed to handle a request.
process_message( &mut self, hdr: VhostUserMsgHeader<FrontendReq>, files: Vec<File>, ) -> Result<()>347     pub fn process_message(
348         &mut self,
349         hdr: VhostUserMsgHeader<FrontendReq>,
350         files: Vec<File>,
351     ) -> Result<()> {
352         let buf = self.connection.recv_body_bytes(&hdr)?;
353         let size = buf.len();
354 
355         // TODO: The error handling here is inconsistent. Sometimes we report the error to the
356         // client and keep going, sometimes we report the error and then close the connection,
357         // sometimes we just close the connection.
358         match hdr.get_code() {
359             Ok(FrontendReq::SET_OWNER) => {
360                 self.check_request_size(&hdr, size, 0)?;
361                 let res = self.backend.set_owner();
362                 self.send_ack_message(&hdr, res.is_ok())?;
363                 res?;
364             }
365             Ok(FrontendReq::RESET_OWNER) => {
366                 self.check_request_size(&hdr, size, 0)?;
367                 let res = self.backend.reset_owner();
368                 self.send_ack_message(&hdr, res.is_ok())?;
369                 res?;
370             }
371             Ok(FrontendReq::GET_FEATURES) => {
372                 self.check_request_size(&hdr, size, 0)?;
373                 let mut features = self.backend.get_features()?;
374 
375                 // Don't advertise packed queues even if the device does. We don't handle them
376                 // properly yet at the protocol layer.
377                 // TODO: b/331466964 - Remove once support is added.
378                 features &= !(1 << VIRTIO_F_RING_PACKED);
379 
380                 let msg = VhostUserU64::new(features);
381                 self.send_reply_message(&hdr, &msg)?;
382                 self.virtio_features = features;
383                 self.update_reply_ack_flag();
384             }
385             Ok(FrontendReq::SET_FEATURES) => {
386                 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
387 
388                 // Don't allow packed queues even if the device does. We don't handle them
389                 // properly yet at the protocol layer.
390                 // TODO: b/331466964 - Remove once support is added.
391                 msg.value &= !(1 << VIRTIO_F_RING_PACKED);
392 
393                 let res = self.backend.set_features(msg.value);
394                 self.acked_virtio_features = msg.value;
395                 self.update_reply_ack_flag();
396                 self.send_ack_message(&hdr, res.is_ok())?;
397                 res?;
398             }
399             Ok(FrontendReq::SET_MEM_TABLE) => {
400                 let res = self.set_mem_table(&hdr, size, &buf, files);
401                 self.send_ack_message(&hdr, res.is_ok())?;
402                 res?;
403             }
404             Ok(FrontendReq::SET_VRING_NUM) => {
405                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
406                 let res = self.backend.set_vring_num(msg.index, msg.num);
407                 self.send_ack_message(&hdr, res.is_ok())?;
408                 res?;
409             }
410             Ok(FrontendReq::SET_VRING_ADDR) => {
411                 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
412                 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
413                     Some(val) => val,
414                     None => return Err(Error::InvalidMessage),
415                 };
416                 let res = self.backend.set_vring_addr(
417                     msg.index,
418                     flags,
419                     msg.descriptor,
420                     msg.used,
421                     msg.available,
422                     msg.log,
423                 );
424                 self.send_ack_message(&hdr, res.is_ok())?;
425                 res?;
426             }
427             Ok(FrontendReq::SET_VRING_BASE) => {
428                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
429                 let res = self.backend.set_vring_base(msg.index, msg.num);
430                 self.send_ack_message(&hdr, res.is_ok())?;
431                 res?;
432             }
433             Ok(FrontendReq::GET_VRING_BASE) => {
434                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
435                 let reply = self.backend.get_vring_base(msg.index)?;
436                 self.send_reply_message(&hdr, &reply)?;
437             }
438             Ok(FrontendReq::SET_VRING_CALL) => {
439                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
440                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
441                 let res = self.backend.set_vring_call(index, file);
442                 self.send_ack_message(&hdr, res.is_ok())?;
443                 res?;
444             }
445             Ok(FrontendReq::SET_VRING_KICK) => {
446                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
447                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
448                 let res = self.backend.set_vring_kick(index, file);
449                 self.send_ack_message(&hdr, res.is_ok())?;
450                 res?;
451             }
452             Ok(FrontendReq::SET_VRING_ERR) => {
453                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
454                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
455                 let res = self.backend.set_vring_err(index, file);
456                 self.send_ack_message(&hdr, res.is_ok())?;
457                 res?;
458             }
459             Ok(FrontendReq::GET_PROTOCOL_FEATURES) => {
460                 self.check_request_size(&hdr, size, 0)?;
461                 let features = self.backend.get_protocol_features()?;
462                 let msg = VhostUserU64::new(features.bits());
463                 self.send_reply_message(&hdr, &msg)?;
464                 self.protocol_features = features;
465                 self.update_reply_ack_flag();
466             }
467             Ok(FrontendReq::SET_PROTOCOL_FEATURES) => {
468                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
469                 let res = self.backend.set_protocol_features(msg.value);
470                 self.acked_protocol_features = msg.value;
471                 self.update_reply_ack_flag();
472                 self.send_ack_message(&hdr, res.is_ok())?;
473                 res?;
474             }
475             Ok(FrontendReq::GET_QUEUE_NUM) => {
476                 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
477                     return Err(Error::InvalidOperation);
478                 }
479                 self.check_request_size(&hdr, size, 0)?;
480                 let num = self.backend.get_queue_num()?;
481                 let msg = VhostUserU64::new(num);
482                 self.send_reply_message(&hdr, &msg)?;
483             }
484             Ok(FrontendReq::SET_VRING_ENABLE) => {
485                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
486                 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
487                     return Err(Error::InvalidOperation);
488                 }
489                 let enable = match msg.num {
490                     1 => true,
491                     0 => false,
492                     _ => return Err(Error::InvalidParam),
493                 };
494 
495                 let res = self.backend.set_vring_enable(msg.index, enable);
496                 self.send_ack_message(&hdr, res.is_ok())?;
497                 res?;
498             }
499             Ok(FrontendReq::GET_CONFIG) => {
500                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
501                     return Err(Error::InvalidOperation);
502                 }
503                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
504                 self.get_config(&hdr, &buf)?;
505             }
506             Ok(FrontendReq::SET_CONFIG) => {
507                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
508                     return Err(Error::InvalidOperation);
509                 }
510                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
511                 let res = self.set_config(&buf);
512                 self.send_ack_message(&hdr, res.is_ok())?;
513                 res?;
514             }
515             Ok(FrontendReq::SET_BACKEND_REQ_FD) => {
516                 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0
517                 {
518                     return Err(Error::InvalidOperation);
519                 }
520                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
521                 let res = self.set_backend_req_fd(files);
522                 self.send_ack_message(&hdr, res.is_ok())?;
523                 res?;
524             }
525             Ok(FrontendReq::GET_INFLIGHT_FD) => {
526                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
527                     == 0
528                 {
529                     return Err(Error::InvalidOperation);
530                 }
531 
532                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
533                 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
534                 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
535                 self.connection.send_message(
536                     &reply_hdr,
537                     &inflight,
538                     Some(&[file.as_raw_descriptor()]),
539                 )?;
540             }
541             Ok(FrontendReq::SET_INFLIGHT_FD) => {
542                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
543                     == 0
544                 {
545                     return Err(Error::InvalidOperation);
546                 }
547                 let file = into_single_file(files).ok_or(Error::IncorrectFds)?;
548                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
549                 let res = self.backend.set_inflight_fd(&msg, file);
550                 self.send_ack_message(&hdr, res.is_ok())?;
551                 res?;
552             }
553             Ok(FrontendReq::GET_MAX_MEM_SLOTS) => {
554                 if self.acked_protocol_features
555                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
556                     == 0
557                 {
558                     return Err(Error::InvalidOperation);
559                 }
560                 self.check_request_size(&hdr, size, 0)?;
561                 let num = self.backend.get_max_mem_slots()?;
562                 let msg = VhostUserU64::new(num);
563                 self.send_reply_message(&hdr, &msg)?;
564             }
565             Ok(FrontendReq::ADD_MEM_REG) => {
566                 if self.acked_protocol_features
567                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
568                     == 0
569                 {
570                     return Err(Error::InvalidOperation);
571                 }
572                 let file = into_single_file(files).ok_or(Error::InvalidParam)?;
573                 let msg =
574                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
575                 let res = self.backend.add_mem_region(&msg, file);
576                 self.send_ack_message(&hdr, res.is_ok())?;
577                 res?;
578             }
579             Ok(FrontendReq::REM_MEM_REG) => {
580                 if self.acked_protocol_features
581                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
582                     == 0
583                 {
584                     return Err(Error::InvalidOperation);
585                 }
586 
587                 let msg =
588                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
589                 let res = self.backend.remove_mem_region(&msg);
590                 self.send_ack_message(&hdr, res.is_ok())?;
591                 res?;
592             }
593             Ok(FrontendReq::SET_DEVICE_STATE_FD) => {
594                 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
595                     == 0
596                 {
597                     return Err(Error::InvalidOperation);
598                 }
599                 // Read request.
600                 let msg =
601                     self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?;
602                 let transfer_direction = match msg.transfer_direction {
603                     0 => VhostUserTransferDirection::Save,
604                     1 => VhostUserTransferDirection::Load,
605                     _ => return Err(Error::InvalidMessage),
606                 };
607                 let migration_phase = match msg.migration_phase {
608                     0 => VhostUserMigrationPhase::Stopped,
609                     _ => return Err(Error::InvalidMessage),
610                 };
611                 // Call backend.
612                 let res = self.backend.set_device_state_fd(
613                     transfer_direction,
614                     migration_phase,
615                     files.into_iter().next().ok_or(Error::IncorrectFds)?,
616                 );
617                 // Send response.
618                 let (msg, fds) = match &res {
619                     Ok(None) => (VhostUserU64::new(0x100), None),
620                     Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())),
621                     // Just in case, set the "invalid FD" flag on error.
622                     Err(_) => (VhostUserU64::new(0x101), None),
623                 };
624                 let reply_hdr: VhostUserMsgHeader<FrontendReq> =
625                     self.new_reply_header::<VhostUserU64>(&hdr, 0)?;
626                 self.connection.send_message(
627                     &reply_hdr,
628                     &msg,
629                     fds.as_ref().map(std::slice::from_ref),
630                 )?;
631                 res?;
632             }
633             Ok(FrontendReq::CHECK_DEVICE_STATE) => {
634                 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits()
635                     == 0
636                 {
637                     return Err(Error::InvalidOperation);
638                 }
639                 let res = self.backend.check_device_state();
640                 let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 });
641                 self.send_reply_message(&hdr, &msg)?;
642                 res?;
643             }
644             Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => {
645                 let regions = self.backend.get_shared_memory_regions()?;
646                 let mut buf = Vec::new();
647                 let msg = VhostUserU64::new(regions.len() as u64);
648                 for r in regions {
649                     buf.extend_from_slice(r.as_bytes())
650                 }
651                 self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
652             }
653             _ => {
654                 return Err(Error::InvalidMessage);
655             }
656         }
657         Ok(())
658     }
659 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<FrontendReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<FrontendReq>>660     fn new_reply_header<T: Sized>(
661         &self,
662         req: &VhostUserMsgHeader<FrontendReq>,
663         payload_size: usize,
664     ) -> Result<VhostUserMsgHeader<FrontendReq>> {
665         Ok(VhostUserMsgHeader::new(
666             req.get_code().map_err(|_| Error::InvalidMessage)?,
667             VhostUserHeaderFlag::REPLY.bits(),
668             (mem::size_of::<T>()
669                 .checked_add(payload_size)
670                 .ok_or(Error::OversizedMsg)?)
671             .try_into()
672             .map_err(Error::InvalidCastToInt)?,
673         ))
674     }
675 
676     /// Sends reply back to Vhost frontend in response to a message.
send_ack_message( &mut self, req: &VhostUserMsgHeader<FrontendReq>, success: bool, ) -> Result<()>677     fn send_ack_message(
678         &mut self,
679         req: &VhostUserMsgHeader<FrontendReq>,
680         success: bool,
681     ) -> Result<()> {
682         if self.reply_ack_enabled && req.is_need_reply() {
683             let hdr: VhostUserMsgHeader<FrontendReq> =
684                 self.new_reply_header::<VhostUserU64>(req, 0)?;
685             let val = if success { 0 } else { 1 };
686             let msg = VhostUserU64::new(val);
687             self.connection.send_message(&hdr, &msg, None)?;
688         }
689         Ok(())
690     }
691 
send_reply_message<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, ) -> Result<()>692     fn send_reply_message<T: Sized + AsBytes>(
693         &mut self,
694         req: &VhostUserMsgHeader<FrontendReq>,
695         msg: &T,
696     ) -> Result<()> {
697         let hdr = self.new_reply_header::<T>(req, 0)?;
698         self.connection.send_message(&hdr, msg, None)?;
699         Ok(())
700     }
701 
send_reply_with_payload<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, payload: &[u8], ) -> Result<()>702     fn send_reply_with_payload<T: Sized + AsBytes>(
703         &mut self,
704         req: &VhostUserMsgHeader<FrontendReq>,
705         msg: &T,
706         payload: &[u8],
707     ) -> Result<()> {
708         let hdr = self.new_reply_header::<T>(req, payload.len())?;
709         self.connection
710             .send_message_with_payload(&hdr, msg, payload, None)?;
711         Ok(())
712     }
713 
set_mem_table( &mut self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, buf: &[u8], files: Vec<File>, ) -> Result<()>714     fn set_mem_table(
715         &mut self,
716         hdr: &VhostUserMsgHeader<FrontendReq>,
717         size: usize,
718         buf: &[u8],
719         files: Vec<File>,
720     ) -> Result<()> {
721         self.check_request_size(hdr, size, hdr.get_size() as usize)?;
722 
723         let (msg, regions) =
724             Ref::<_, VhostUserMemory>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
725         if !msg.is_valid() {
726             return Err(Error::InvalidMessage);
727         }
728 
729         // validate number of fds matching number of memory regions
730         if files.len() != msg.num_regions as usize {
731             return Err(Error::InvalidMessage);
732         }
733 
734         let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::new_slice_from_prefix(
735             regions,
736             msg.num_regions as usize,
737         )
738         .ok_or(Error::InvalidMessage)?;
739         if !excess.is_empty() {
740             return Err(Error::InvalidMessage);
741         }
742 
743         // Validate memory regions
744         for region in regions.iter() {
745             if !region.is_valid() {
746                 return Err(Error::InvalidMessage);
747             }
748         }
749 
750         self.backend.set_mem_table(&regions, files)
751     }
752 
get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()>753     fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
754         let (msg, payload) =
755             Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
756         if !msg.is_valid() {
757             return Err(Error::InvalidMessage);
758         }
759         if payload.len() != msg.size as usize {
760             return Err(Error::InvalidMessage);
761         }
762         let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
763             Some(val) => val,
764             None => return Err(Error::InvalidMessage),
765         };
766         let res = self.backend.get_config(msg.offset, msg.size, flags);
767 
768         // The response payload size MUST match the request payload size on success. A zero length
769         // response is used to indicate an error.
770         match res {
771             Ok(ref buf) if buf.len() == msg.size as usize => {
772                 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
773                 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
774             }
775             Ok(_) => {
776                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
777                 self.send_reply_message(hdr, &reply)?;
778             }
779             Err(_) => {
780                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
781                 self.send_reply_message(hdr, &reply)?;
782             }
783         }
784         Ok(())
785     }
786 
set_config(&mut self, buf: &[u8]) -> Result<()>787     fn set_config(&mut self, buf: &[u8]) -> Result<()> {
788         let (msg, payload) =
789             Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
790         if !msg.is_valid() {
791             return Err(Error::InvalidMessage);
792         }
793         if payload.len() != msg.size as usize {
794             return Err(Error::InvalidMessage);
795         }
796         let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
797             Some(val) => val,
798             None => return Err(Error::InvalidMessage),
799         };
800 
801         self.backend.set_config(msg.offset, payload, flags)
802     }
803 
set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()>804     fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
805         let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
806         let fd: SafeDescriptor = file.into();
807         let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
808         self.backend.set_backend_req_fd(connection);
809         Ok(())
810     }
811 
812     /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
813     /// Vring number and an fd.
handle_vring_fd_request( &mut self, buf: &[u8], files: Vec<File>, ) -> Result<(u8, Option<File>)>814     fn handle_vring_fd_request(
815         &mut self,
816         buf: &[u8],
817         files: Vec<File>,
818     ) -> Result<(u8, Option<File>)> {
819         let msg = VhostUserU64::read_from_prefix(buf).ok_or(Error::InvalidMessage)?;
820         if !msg.is_valid() {
821             return Err(Error::InvalidMessage);
822         }
823 
824         // Bits (0-7) of the payload contain the vring index. Bit 8 is the
825         // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
826         // This bit is set when there is no file descriptor
827         // in the ancillary data. This signals that polling will be used
828         // instead of waiting for the call.
829         // If Bit 8 is unset, the data must contain a file descriptor.
830         let has_fd = (msg.value & 0x100u64) == 0;
831 
832         let file = into_single_file(files);
833 
834         if has_fd && file.is_none() || !has_fd && file.is_some() {
835             return Err(Error::InvalidMessage);
836         }
837 
838         Ok((msg.value as u8, file))
839     }
840 
check_request_size( &self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, expected: usize, ) -> Result<()>841     fn check_request_size(
842         &self,
843         hdr: &VhostUserMsgHeader<FrontendReq>,
844         size: usize,
845         expected: usize,
846     ) -> Result<()> {
847         if hdr.get_size() as usize != expected
848             || hdr.is_reply()
849             || hdr.get_version() != 0x1
850             || size != expected
851         {
852             return Err(Error::InvalidMessage);
853         }
854         Ok(())
855     }
856 
check_attached_files( &self, hdr: &VhostUserMsgHeader<FrontendReq>, files: &[File], ) -> Result<()>857     fn check_attached_files(
858         &self,
859         hdr: &VhostUserMsgHeader<FrontendReq>,
860         files: &[File],
861     ) -> Result<()> {
862         match hdr.get_code() {
863             Ok(FrontendReq::SET_MEM_TABLE)
864             | Ok(FrontendReq::SET_VRING_CALL)
865             | Ok(FrontendReq::SET_VRING_KICK)
866             | Ok(FrontendReq::SET_VRING_ERR)
867             | Ok(FrontendReq::SET_LOG_BASE)
868             | Ok(FrontendReq::SET_LOG_FD)
869             | Ok(FrontendReq::SET_BACKEND_REQ_FD)
870             | Ok(FrontendReq::SET_INFLIGHT_FD)
871             | Ok(FrontendReq::ADD_MEM_REG)
872             | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
873             Err(_) => Err(Error::InvalidMessage),
874             _ if !files.is_empty() => Err(Error::InvalidMessage),
875             _ => Ok(()),
876         }
877     }
878 
extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, buf: &[u8], ) -> Result<T>879     fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
880         &self,
881         hdr: &VhostUserMsgHeader<FrontendReq>,
882         size: usize,
883         buf: &[u8],
884     ) -> Result<T> {
885         self.check_request_size(hdr, size, mem::size_of::<T>())?;
886         T::read_from_prefix(buf)
887             .filter(T::is_valid)
888             .ok_or(Error::InvalidMessage)
889     }
890 
update_reply_ack_flag(&mut self)891     fn update_reply_ack_flag(&mut self) {
892         let pflag = VhostUserProtocolFeatures::REPLY_ACK;
893         self.reply_ack_enabled = (self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES) != 0
894             && self.protocol_features.contains(pflag)
895             && (self.acked_protocol_features & pflag.bits()) != 0;
896     }
897 }
898 
899 impl<S: Backend> AsRawDescriptor for BackendServer<S> {
as_raw_descriptor(&self) -> RawDescriptor900     fn as_raw_descriptor(&self) -> RawDescriptor {
901         // TODO(b/221882601): figure out if this used for polling.
902         self.connection.as_raw_descriptor()
903     }
904 }
905 
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    /// A freshly constructed server exposes a valid descriptor for its connection.
    #[test]
    fn test_backend_server_new() {
        let (server_end, _peer_end) = Connection::pair().unwrap();
        let server = BackendServer::new(server_end, TestBackend::new());

        assert_ne!(server.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}
923