1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 7 use base::AsRawDescriptor; 8 use base::RawDescriptor; 9 use base::SafeDescriptor; 10 use zerocopy::AsBytes; 11 use zerocopy::FromBytes; 12 use zerocopy::Ref; 13 14 use crate::into_single_file; 15 use crate::message::*; 16 use crate::BackendReq; 17 use crate::Connection; 18 use crate::Error; 19 use crate::FrontendReq; 20 use crate::Result; 21 22 /// Trait for vhost-user backends. 23 /// 24 /// Each method corresponds to a vhost-user protocol method. See the specification for details. 25 #[allow(missing_docs)] 26 pub trait Backend { set_owner(&mut self) -> Result<()>27 fn set_owner(&mut self) -> Result<()>; reset_owner(&mut self) -> Result<()>28 fn reset_owner(&mut self) -> Result<()>; get_features(&mut self) -> Result<u64>29 fn get_features(&mut self) -> Result<u64>; set_features(&mut self, features: u64) -> Result<()>30 fn set_features(&mut self, features: u64) -> Result<()>; set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>31 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&mut self, index: u32, num: u32) -> Result<()>32 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>33 fn set_vring_addr( 34 &mut self, 35 index: u32, 36 flags: VhostUserVringAddrFlags, 37 descriptor: u64, 38 used: u64, 39 available: u64, 40 log: u64, 41 ) -> Result<()>; 42 // TODO: b/331466964 - Argument type is wrong for packed queues. set_vring_base(&mut self, index: u32, base: u32) -> Result<()>43 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; 44 // TODO: b/331466964 - Return type is wrong for packed queues. 
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>45 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>46 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>47 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>48 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; 49 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>50 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&mut self, features: u64) -> Result<()>51 fn set_protocol_features(&mut self, features: u64) -> Result<()>; get_queue_num(&mut self) -> Result<u64>52 fn get_queue_num(&mut self) -> Result<u64>; set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>53 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>54 fn get_config( 55 &mut self, 56 offset: u32, 57 size: u32, 58 flags: VhostUserConfigFlags, 59 ) -> Result<Vec<u8>>; set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>60 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>)61 fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {} get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>62 fn get_inflight_fd( 63 &mut self, 64 inflight: &VhostUserInflight, 65 ) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>66 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; 
get_max_mem_slots(&mut self) -> Result<u64>67 fn get_max_mem_slots(&mut self) -> Result<u64>; add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>68 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>69 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, fd: File, ) -> Result<Option<File>>70 fn set_device_state_fd( 71 &mut self, 72 transfer_direction: VhostUserTransferDirection, 73 migration_phase: VhostUserMigrationPhase, 74 fd: File, 75 ) -> Result<Option<File>>; check_device_state(&mut self) -> Result<()>76 fn check_device_state(&mut self) -> Result<()>; get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>77 fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>; 78 } 79 80 impl<T> Backend for T 81 where 82 T: AsMut<dyn Backend>, 83 { set_owner(&mut self) -> Result<()>84 fn set_owner(&mut self) -> Result<()> { 85 self.as_mut().set_owner() 86 } 87 reset_owner(&mut self) -> Result<()>88 fn reset_owner(&mut self) -> Result<()> { 89 self.as_mut().reset_owner() 90 } 91 get_features(&mut self) -> Result<u64>92 fn get_features(&mut self) -> Result<u64> { 93 self.as_mut().get_features() 94 } 95 set_features(&mut self, features: u64) -> Result<()>96 fn set_features(&mut self, features: u64) -> Result<()> { 97 self.as_mut().set_features(features) 98 } 99 set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>100 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { 101 self.as_mut().set_mem_table(ctx, files) 102 } 103 set_vring_num(&mut self, index: u32, num: u32) -> Result<()>104 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { 105 
self.as_mut().set_vring_num(index, num) 106 } 107 set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>108 fn set_vring_addr( 109 &mut self, 110 index: u32, 111 flags: VhostUserVringAddrFlags, 112 descriptor: u64, 113 used: u64, 114 available: u64, 115 log: u64, 116 ) -> Result<()> { 117 self.as_mut() 118 .set_vring_addr(index, flags, descriptor, used, available, log) 119 } 120 set_vring_base(&mut self, index: u32, base: u32) -> Result<()>121 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { 122 self.as_mut().set_vring_base(index, base) 123 } 124 get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>125 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { 126 self.as_mut().get_vring_base(index) 127 } 128 set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>129 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> { 130 self.as_mut().set_vring_kick(index, fd) 131 } 132 set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>133 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> { 134 self.as_mut().set_vring_call(index, fd) 135 } 136 set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>137 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> { 138 self.as_mut().set_vring_err(index, fd) 139 } 140 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>141 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { 142 self.as_mut().get_protocol_features() 143 } 144 set_protocol_features(&mut self, features: u64) -> Result<()>145 fn set_protocol_features(&mut self, features: u64) -> Result<()> { 146 self.as_mut().set_protocol_features(features) 147 } 148 get_queue_num(&mut self) -> Result<u64>149 fn get_queue_num(&mut self) -> Result<u64> { 150 self.as_mut().get_queue_num() 151 } 152 set_vring_enable(&mut 
self, index: u32, enable: bool) -> Result<()>153 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { 154 self.as_mut().set_vring_enable(index, enable) 155 } 156 get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>157 fn get_config( 158 &mut self, 159 offset: u32, 160 size: u32, 161 flags: VhostUserConfigFlags, 162 ) -> Result<Vec<u8>> { 163 self.as_mut().get_config(offset, size, flags) 164 } 165 set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>166 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { 167 self.as_mut().set_config(offset, buf, flags) 168 } 169 set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>)170 fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) { 171 self.as_mut().set_backend_req_fd(vu_req) 172 } 173 get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>174 fn get_inflight_fd( 175 &mut self, 176 inflight: &VhostUserInflight, 177 ) -> Result<(VhostUserInflight, File)> { 178 self.as_mut().get_inflight_fd(inflight) 179 } 180 set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>181 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> { 182 self.as_mut().set_inflight_fd(inflight, file) 183 } 184 get_max_mem_slots(&mut self) -> Result<u64>185 fn get_max_mem_slots(&mut self) -> Result<u64> { 186 self.as_mut().get_max_mem_slots() 187 } 188 add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>189 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { 190 self.as_mut().add_mem_region(region, fd) 191 } 192 remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>193 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> { 194 self.as_mut().remove_mem_region(region) 
195 } 196 set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, fd: File, ) -> Result<Option<File>>197 fn set_device_state_fd( 198 &mut self, 199 transfer_direction: VhostUserTransferDirection, 200 migration_phase: VhostUserMigrationPhase, 201 fd: File, 202 ) -> Result<Option<File>> { 203 self.as_mut() 204 .set_device_state_fd(transfer_direction, migration_phase, fd) 205 } 206 check_device_state(&mut self) -> Result<()>207 fn check_device_state(&mut self) -> Result<()> { 208 self.as_mut().check_device_state() 209 } 210 get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>211 fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> { 212 self.as_mut().get_shared_memory_regions() 213 } 214 } 215 216 /// Handles requests from a vhost-user connection by dispatching them to [[Backend]] methods. 217 pub struct BackendServer<S: Backend> { 218 /// Underlying connection for communication. 219 connection: Connection<FrontendReq>, 220 // the vhost-user backend device object 221 backend: S, 222 223 virtio_features: u64, 224 acked_virtio_features: u64, 225 protocol_features: VhostUserProtocolFeatures, 226 acked_protocol_features: u64, 227 228 /// Sending ack for messages without payload. 229 reply_ack_enabled: bool, 230 } 231 232 impl<S: Backend> AsRef<S> for BackendServer<S> { as_ref(&self) -> &S233 fn as_ref(&self) -> &S { 234 &self.backend 235 } 236 } 237 238 impl<S: Backend> BackendServer<S> { new(connection: Connection<FrontendReq>, backend: S) -> Self239 pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self { 240 BackendServer { 241 connection, 242 backend, 243 virtio_features: 0, 244 acked_virtio_features: 0, 245 protocol_features: VhostUserProtocolFeatures::empty(), 246 acked_protocol_features: 0, 247 reply_ack_enabled: false, 248 } 249 } 250 251 /// Receives and validates a vhost-user message header and optional files. 
252 /// 253 /// Since the length of vhost-user messages are different among message types, regular 254 /// vhost-user messages are sent via an underlying communication channel in stream mode. 255 /// (e.g. `SOCK_STREAM` in UNIX) 256 /// So, the logic of receiving and handling a message consists of the following steps: 257 /// 258 /// 1. Receives a message header and optional attached file. 259 /// 2. Validates the message header. 260 /// 3. Check if optional payloads is expected. 261 /// 4. Wait for the optional payloads. 262 /// 5. Receives optional payloads. 263 /// 6. Processes the message. 264 /// 265 /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2), 266 /// [`BackendServer::needs_wait_for_payload()`] is (3), and 267 /// [`BackendServer::process_message()`] is (5) and (6). We need to have the three method 268 /// separately for multi-platform supports; [`BackendServer::recv_header()`] and 269 /// [`BackendServer::process_message()`] need to be separated because the way of waiting for 270 /// incoming messages differs between Unix and Windows so it's the caller's responsibility to 271 /// wait before [`BackendServer::process_message()`]. 272 /// 273 /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this 274 /// case, a message header and its body are sent together so the step (4) is skipped. We handle 275 /// this case in [`BackendServer::needs_wait_for_payload()`]. 276 /// 277 /// The following pseudo code describes how a caller should process incoming vhost-user 278 /// messages: 279 /// ```ignore 280 /// loop { 281 /// // block until a message header comes. 282 /// // The actual code differs, depending on platforms. 283 /// connection.wait_readable().unwrap(); 284 /// 285 /// // (1) and (2) 286 /// let (hdr, files) = backend_server.recv_header(); 287 /// 288 /// // (3) 289 /// if backend_server.needs_wait_for_payload(&hdr) { 290 /// // (4) block until a payload comes if needed. 
291 /// connection.wait_readable().unwrap(); 292 /// } 293 /// 294 /// // (5) and (6) 295 /// backend_server.process_message(&hdr, &files).unwrap(); 296 /// } 297 /// ``` recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)>298 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> { 299 // The underlying communication channel is a Unix domain socket in 300 // stream mode, and recvmsg() is a little tricky here. To successfully 301 // receive attached file descriptors, we need to receive messages and 302 // corresponding attached file descriptors in this way: 303 // . recv messsage header and optional attached file 304 // . validate message header 305 // . recv optional message body and payload according size field in 306 // message header 307 // . validate message body and optional payload 308 let (hdr, files) = match self.connection.recv_header() { 309 Ok((hdr, files)) => (hdr, files), 310 Err(Error::Disconnect) => { 311 // If the client closed the connection before sending a header, this should be 312 // handled as a legal exit. 313 return Err(Error::ClientExit); 314 } 315 Err(e) => { 316 return Err(e); 317 } 318 }; 319 320 self.check_attached_files(&hdr, &files)?; 321 322 Ok((hdr, files)) 323 } 324 325 /// Returns whether the caller needs to wait for the incoming message before calling 326 /// [`BackendServer::process_message`]. 327 /// 328 /// See [`BackendServer::recv_header`]'s doc comment for the usage. needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool329 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool { 330 // Since the vhost-user protocol uses stream mode, we need to wait until an additional 331 // payload is available if exists. 332 hdr.get_size() != 0 333 } 334 335 /// Main entrance to request from the communication channel. 336 /// 337 /// Receive and handle one incoming request message from the frontend. 
338 /// See [`BackendServer::recv_header`]'s doc comment for the usage. 339 /// 340 /// # Return: 341 /// * `Ok(())`: one request was successfully handled. 342 /// * `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual 343 /// failure. 344 /// * `Err(Disconnect)`: the connection was closed unexpectedly. 345 /// * `Err(InvalidMessage)`: the vmm sent a illegal message. 346 /// * other errors: failed to handle a request. process_message( &mut self, hdr: VhostUserMsgHeader<FrontendReq>, files: Vec<File>, ) -> Result<()>347 pub fn process_message( 348 &mut self, 349 hdr: VhostUserMsgHeader<FrontendReq>, 350 files: Vec<File>, 351 ) -> Result<()> { 352 let buf = self.connection.recv_body_bytes(&hdr)?; 353 let size = buf.len(); 354 355 // TODO: The error handling here is inconsistent. Sometimes we report the error to the 356 // client and keep going, sometimes we report the error and then close the connection, 357 // sometimes we just close the connection. 358 match hdr.get_code() { 359 Ok(FrontendReq::SET_OWNER) => { 360 self.check_request_size(&hdr, size, 0)?; 361 let res = self.backend.set_owner(); 362 self.send_ack_message(&hdr, res.is_ok())?; 363 res?; 364 } 365 Ok(FrontendReq::RESET_OWNER) => { 366 self.check_request_size(&hdr, size, 0)?; 367 let res = self.backend.reset_owner(); 368 self.send_ack_message(&hdr, res.is_ok())?; 369 res?; 370 } 371 Ok(FrontendReq::GET_FEATURES) => { 372 self.check_request_size(&hdr, size, 0)?; 373 let mut features = self.backend.get_features()?; 374 375 // Don't advertise packed queues even if the device does. We don't handle them 376 // properly yet at the protocol layer. 377 // TODO: b/331466964 - Remove once support is added. 
378 features &= !(1 << VIRTIO_F_RING_PACKED); 379 380 let msg = VhostUserU64::new(features); 381 self.send_reply_message(&hdr, &msg)?; 382 self.virtio_features = features; 383 self.update_reply_ack_flag(); 384 } 385 Ok(FrontendReq::SET_FEATURES) => { 386 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 387 388 // Don't allow packed queues even if the device does. We don't handle them 389 // properly yet at the protocol layer. 390 // TODO: b/331466964 - Remove once support is added. 391 msg.value &= !(1 << VIRTIO_F_RING_PACKED); 392 393 let res = self.backend.set_features(msg.value); 394 self.acked_virtio_features = msg.value; 395 self.update_reply_ack_flag(); 396 self.send_ack_message(&hdr, res.is_ok())?; 397 res?; 398 } 399 Ok(FrontendReq::SET_MEM_TABLE) => { 400 let res = self.set_mem_table(&hdr, size, &buf, files); 401 self.send_ack_message(&hdr, res.is_ok())?; 402 res?; 403 } 404 Ok(FrontendReq::SET_VRING_NUM) => { 405 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 406 let res = self.backend.set_vring_num(msg.index, msg.num); 407 self.send_ack_message(&hdr, res.is_ok())?; 408 res?; 409 } 410 Ok(FrontendReq::SET_VRING_ADDR) => { 411 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; 412 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { 413 Some(val) => val, 414 None => return Err(Error::InvalidMessage), 415 }; 416 let res = self.backend.set_vring_addr( 417 msg.index, 418 flags, 419 msg.descriptor, 420 msg.used, 421 msg.available, 422 msg.log, 423 ); 424 self.send_ack_message(&hdr, res.is_ok())?; 425 res?; 426 } 427 Ok(FrontendReq::SET_VRING_BASE) => { 428 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 429 let res = self.backend.set_vring_base(msg.index, msg.num); 430 self.send_ack_message(&hdr, res.is_ok())?; 431 res?; 432 } 433 Ok(FrontendReq::GET_VRING_BASE) => { 434 let msg = 
self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 435 let reply = self.backend.get_vring_base(msg.index)?; 436 self.send_reply_message(&hdr, &reply)?; 437 } 438 Ok(FrontendReq::SET_VRING_CALL) => { 439 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 440 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 441 let res = self.backend.set_vring_call(index, file); 442 self.send_ack_message(&hdr, res.is_ok())?; 443 res?; 444 } 445 Ok(FrontendReq::SET_VRING_KICK) => { 446 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 447 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 448 let res = self.backend.set_vring_kick(index, file); 449 self.send_ack_message(&hdr, res.is_ok())?; 450 res?; 451 } 452 Ok(FrontendReq::SET_VRING_ERR) => { 453 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 454 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 455 let res = self.backend.set_vring_err(index, file); 456 self.send_ack_message(&hdr, res.is_ok())?; 457 res?; 458 } 459 Ok(FrontendReq::GET_PROTOCOL_FEATURES) => { 460 self.check_request_size(&hdr, size, 0)?; 461 let features = self.backend.get_protocol_features()?; 462 let msg = VhostUserU64::new(features.bits()); 463 self.send_reply_message(&hdr, &msg)?; 464 self.protocol_features = features; 465 self.update_reply_ack_flag(); 466 } 467 Ok(FrontendReq::SET_PROTOCOL_FEATURES) => { 468 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 469 let res = self.backend.set_protocol_features(msg.value); 470 self.acked_protocol_features = msg.value; 471 self.update_reply_ack_flag(); 472 self.send_ack_message(&hdr, res.is_ok())?; 473 res?; 474 } 475 Ok(FrontendReq::GET_QUEUE_NUM) => { 476 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { 477 return Err(Error::InvalidOperation); 478 } 479 self.check_request_size(&hdr, size, 0)?; 480 let num = self.backend.get_queue_num()?; 481 let 
msg = VhostUserU64::new(num); 482 self.send_reply_message(&hdr, &msg)?; 483 } 484 Ok(FrontendReq::SET_VRING_ENABLE) => { 485 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 486 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 { 487 return Err(Error::InvalidOperation); 488 } 489 let enable = match msg.num { 490 1 => true, 491 0 => false, 492 _ => return Err(Error::InvalidParam), 493 }; 494 495 let res = self.backend.set_vring_enable(msg.index, enable); 496 self.send_ack_message(&hdr, res.is_ok())?; 497 res?; 498 } 499 Ok(FrontendReq::GET_CONFIG) => { 500 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 501 return Err(Error::InvalidOperation); 502 } 503 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 504 self.get_config(&hdr, &buf)?; 505 } 506 Ok(FrontendReq::SET_CONFIG) => { 507 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 508 return Err(Error::InvalidOperation); 509 } 510 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 511 let res = self.set_config(&buf); 512 self.send_ack_message(&hdr, res.is_ok())?; 513 res?; 514 } 515 Ok(FrontendReq::SET_BACKEND_REQ_FD) => { 516 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0 517 { 518 return Err(Error::InvalidOperation); 519 } 520 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 521 let res = self.set_backend_req_fd(files); 522 self.send_ack_message(&hdr, res.is_ok())?; 523 res?; 524 } 525 Ok(FrontendReq::GET_INFLIGHT_FD) => { 526 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 527 == 0 528 { 529 return Err(Error::InvalidOperation); 530 } 531 532 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 533 let (inflight, file) = self.backend.get_inflight_fd(&msg)?; 534 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; 535 self.connection.send_message( 
536 &reply_hdr, 537 &inflight, 538 Some(&[file.as_raw_descriptor()]), 539 )?; 540 } 541 Ok(FrontendReq::SET_INFLIGHT_FD) => { 542 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 543 == 0 544 { 545 return Err(Error::InvalidOperation); 546 } 547 let file = into_single_file(files).ok_or(Error::IncorrectFds)?; 548 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 549 let res = self.backend.set_inflight_fd(&msg, file); 550 self.send_ack_message(&hdr, res.is_ok())?; 551 res?; 552 } 553 Ok(FrontendReq::GET_MAX_MEM_SLOTS) => { 554 if self.acked_protocol_features 555 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 556 == 0 557 { 558 return Err(Error::InvalidOperation); 559 } 560 self.check_request_size(&hdr, size, 0)?; 561 let num = self.backend.get_max_mem_slots()?; 562 let msg = VhostUserU64::new(num); 563 self.send_reply_message(&hdr, &msg)?; 564 } 565 Ok(FrontendReq::ADD_MEM_REG) => { 566 if self.acked_protocol_features 567 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 568 == 0 569 { 570 return Err(Error::InvalidOperation); 571 } 572 let file = into_single_file(files).ok_or(Error::InvalidParam)?; 573 let msg = 574 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 575 let res = self.backend.add_mem_region(&msg, file); 576 self.send_ack_message(&hdr, res.is_ok())?; 577 res?; 578 } 579 Ok(FrontendReq::REM_MEM_REG) => { 580 if self.acked_protocol_features 581 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 582 == 0 583 { 584 return Err(Error::InvalidOperation); 585 } 586 587 let msg = 588 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 589 let res = self.backend.remove_mem_region(&msg); 590 self.send_ack_message(&hdr, res.is_ok())?; 591 res?; 592 } 593 Ok(FrontendReq::SET_DEVICE_STATE_FD) => { 594 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits() 595 == 0 596 { 597 return Err(Error::InvalidOperation); 
598 } 599 // Read request. 600 let msg = 601 self.extract_request_body::<DeviceStateTransferParameters>(&hdr, size, &buf)?; 602 let transfer_direction = match msg.transfer_direction { 603 0 => VhostUserTransferDirection::Save, 604 1 => VhostUserTransferDirection::Load, 605 _ => return Err(Error::InvalidMessage), 606 }; 607 let migration_phase = match msg.migration_phase { 608 0 => VhostUserMigrationPhase::Stopped, 609 _ => return Err(Error::InvalidMessage), 610 }; 611 // Call backend. 612 let res = self.backend.set_device_state_fd( 613 transfer_direction, 614 migration_phase, 615 files.into_iter().next().ok_or(Error::IncorrectFds)?, 616 ); 617 // Send response. 618 let (msg, fds) = match &res { 619 Ok(None) => (VhostUserU64::new(0x100), None), 620 Ok(Some(file)) => (VhostUserU64::new(0), Some(file.as_raw_descriptor())), 621 // Just in case, set the "invalid FD" flag on error. 622 Err(_) => (VhostUserU64::new(0x101), None), 623 }; 624 let reply_hdr: VhostUserMsgHeader<FrontendReq> = 625 self.new_reply_header::<VhostUserU64>(&hdr, 0)?; 626 self.connection.send_message( 627 &reply_hdr, 628 &msg, 629 fds.as_ref().map(std::slice::from_ref), 630 )?; 631 res?; 632 } 633 Ok(FrontendReq::CHECK_DEVICE_STATE) => { 634 if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits() 635 == 0 636 { 637 return Err(Error::InvalidOperation); 638 } 639 let res = self.backend.check_device_state(); 640 let msg = VhostUserU64::new(if res.is_ok() { 0 } else { 1 }); 641 self.send_reply_message(&hdr, &msg)?; 642 res?; 643 } 644 Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => { 645 let regions = self.backend.get_shared_memory_regions()?; 646 let mut buf = Vec::new(); 647 let msg = VhostUserU64::new(regions.len() as u64); 648 for r in regions { 649 buf.extend_from_slice(r.as_bytes()) 650 } 651 self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?; 652 } 653 _ => { 654 return Err(Error::InvalidMessage); 655 } 656 } 657 Ok(()) 658 } 659 new_reply_header<T: Sized>( &self, 
req: &VhostUserMsgHeader<FrontendReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<FrontendReq>>660 fn new_reply_header<T: Sized>( 661 &self, 662 req: &VhostUserMsgHeader<FrontendReq>, 663 payload_size: usize, 664 ) -> Result<VhostUserMsgHeader<FrontendReq>> { 665 Ok(VhostUserMsgHeader::new( 666 req.get_code().map_err(|_| Error::InvalidMessage)?, 667 VhostUserHeaderFlag::REPLY.bits(), 668 (mem::size_of::<T>() 669 .checked_add(payload_size) 670 .ok_or(Error::OversizedMsg)?) 671 .try_into() 672 .map_err(Error::InvalidCastToInt)?, 673 )) 674 } 675 676 /// Sends reply back to Vhost frontend in response to a message. send_ack_message( &mut self, req: &VhostUserMsgHeader<FrontendReq>, success: bool, ) -> Result<()>677 fn send_ack_message( 678 &mut self, 679 req: &VhostUserMsgHeader<FrontendReq>, 680 success: bool, 681 ) -> Result<()> { 682 if self.reply_ack_enabled && req.is_need_reply() { 683 let hdr: VhostUserMsgHeader<FrontendReq> = 684 self.new_reply_header::<VhostUserU64>(req, 0)?; 685 let val = if success { 0 } else { 1 }; 686 let msg = VhostUserU64::new(val); 687 self.connection.send_message(&hdr, &msg, None)?; 688 } 689 Ok(()) 690 } 691 send_reply_message<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, ) -> Result<()>692 fn send_reply_message<T: Sized + AsBytes>( 693 &mut self, 694 req: &VhostUserMsgHeader<FrontendReq>, 695 msg: &T, 696 ) -> Result<()> { 697 let hdr = self.new_reply_header::<T>(req, 0)?; 698 self.connection.send_message(&hdr, msg, None)?; 699 Ok(()) 700 } 701 send_reply_with_payload<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, payload: &[u8], ) -> Result<()>702 fn send_reply_with_payload<T: Sized + AsBytes>( 703 &mut self, 704 req: &VhostUserMsgHeader<FrontendReq>, 705 msg: &T, 706 payload: &[u8], 707 ) -> Result<()> { 708 let hdr = self.new_reply_header::<T>(req, payload.len())?; 709 self.connection 710 .send_message_with_payload(&hdr, msg, payload, None)?; 711 Ok(()) 712 
}

    /// Handles a memory-table update from the frontend.
    ///
    /// `buf` must hold a `VhostUserMemory` header followed by exactly `num_regions`
    /// `VhostUserMemoryRegion` entries with no trailing bytes. Exactly one file must be
    /// attached per region, and every region must pass validation before the table is
    /// forwarded to the backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` if the header, the region list or the number of
    /// attached files is malformed. Otherwise returns the result of the backend's
    /// `set_mem_table`.
    fn set_mem_table(
        &mut self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        size: usize,
        buf: &[u8],
        files: Vec<File>,
    ) -> Result<()> {
        self.check_request_size(hdr, size, hdr.get_size() as usize)?;

        let (msg, regions) =
            Ref::<_, VhostUserMemory>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }

        // Each memory region must be backed by exactly one attached file descriptor.
        let num_regions = msg.num_regions as usize;
        if files.len() != num_regions {
            return Err(Error::InvalidMessage);
        }

        let (regions, excess) =
            Ref::<_, [VhostUserMemoryRegion]>::new_slice_from_prefix(regions, num_regions)
                .ok_or(Error::InvalidMessage)?;
        // Bytes left over after the declared regions, or any invalid region, make the whole
        // message malformed.
        if !excess.is_empty() || !regions.iter().all(|region| region.is_valid()) {
            return Err(Error::InvalidMessage);
        }

        self.backend.set_mem_table(&regions, files)
    }

    /// Handles a config-space read request.
    ///
    /// `buf` must hold a `VhostUserConfig` header whose `size` equals the length of the
    /// payload that follows it. The reply always echoes the requested offset and flags.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` for a malformed request, or any error raised while
    /// sending the reply. A backend failure is not returned to the caller. It is reported to
    /// the frontend as a zero-length reply instead.
    fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> {
        let (msg, payload) =
            Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
        if !msg.is_valid() || payload.len() != msg.size as usize {
            return Err(Error::InvalidMessage);
        }
        let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;

        // The response payload size MUST match the request payload size on success. A zero
        // length response is used to indicate an error. That covers both a backend failure and
        // a backend that produced the wrong number of bytes.
        match self.backend.get_config(msg.offset, msg.size, flags) {
            Ok(data) if data.len() == msg.size as usize => {
                let reply = VhostUserConfig::new(msg.offset, msg.size, flags);
                self.send_reply_with_payload(hdr, &reply, data.as_slice())?;
            }
            _ => {
                let reply = VhostUserConfig::new(msg.offset, 0, flags);
                self.send_reply_message(hdr, &reply)?;
            }
        }
        Ok(())
    }

    /// Handles a config-space write request.
    ///
    /// `buf` must hold a `VhostUserConfig` header followed by exactly `size` bytes of
    /// payload. The payload is passed to the backend together with the offset and the
    /// parsed flags.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` for a malformed request. Otherwise returns the result
    /// of the backend's `set_config`.
    fn set_config(&mut self, buf: &[u8]) -> Result<()> {
        let (msg, payload) =
            Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?;
        if !msg.is_valid() || payload.len() != msg.size as usize {
            return Err(Error::InvalidMessage);
        }
        let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;

        self.backend.set_config(msg.offset, payload, flags)
    }

    /// Turns the attached file into a `Connection` and hands it to the backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` if `into_single_file` yields no file, or if the
    /// descriptor cannot be converted into a `Connection`.
    fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> {
        let file = into_single_file(files).ok_or(Error::InvalidMessage)?;
        let fd: SafeDescriptor = file.into();
        let connection = Connection::try_from(fd).map_err(|_| Error::InvalidMessage)?;
        self.backend.set_backend_req_fd(connection);
        Ok(())
    }

    /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
    /// Vring number and an fd.
    ///
    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag
    /// (VHOST_USER_VRING_NOFD_MASK). It is set when there is no file descriptor in the
    /// ancillary data, which signals that polling will be used instead of waiting for the
    /// call. If bit 8 is unset, `into_single_file` must yield a file.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` if the payload is malformed, or if the NOFD flag
    /// disagrees with the presence of an attached file.
    fn handle_vring_fd_request(
        &mut self,
        buf: &[u8],
        files: Vec<File>,
    ) -> Result<(u8, Option<File>)> {
        /// Payload bit meaning "no file descriptor attached".
        const NOFD_MASK: u64 = 0x100;

        let msg = VhostUserU64::read_from_prefix(buf).ok_or(Error::InvalidMessage)?;
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }

        let has_fd = (msg.value & NOFD_MASK) == 0;
        let file = into_single_file(files);

        // The NOFD flag and the presence of an attached file must agree.
        if has_fd != file.is_some() {
            return Err(Error::InvalidMessage);
        }

        // Truncating to the low 8 bits is intentional: that is where the vring index lives.
        Ok((msg.value as u8, file))
    }

    /// Checks that a request header is well formed and that the received size is `expected`.
    ///
    /// Rejects the request if the header's size or the received `size` differs from
    /// `expected`, if the header is flagged as a reply, or if the version is not `0x1`.
    fn check_request_size(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }

    /// Rejects file descriptors attached to requests that are not expected to carry any.
    ///
    /// The requests listed below may carry files, and their file count is not checked here.
    /// Any other recognized request must arrive with no files. An unrecognized request code
    /// is always rejected.
    fn check_attached_files(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        files: &[File],
    ) -> Result<()> {
        match hdr.get_code() {
            Ok(FrontendReq::SET_MEM_TABLE)
            | Ok(FrontendReq::SET_VRING_CALL)
            | Ok(FrontendReq::SET_VRING_KICK)
            | Ok(FrontendReq::SET_VRING_ERR)
            | Ok(FrontendReq::SET_LOG_BASE)
            | Ok(FrontendReq::SET_LOG_FD)
            | Ok(FrontendReq::SET_BACKEND_REQ_FD)
            | Ok(FrontendReq::SET_INFLIGHT_FD)
            | Ok(FrontendReq::ADD_MEM_REG)
            | Ok(FrontendReq::SET_DEVICE_STATE_FD) => Ok(()),
            Err(_) => Err(Error::InvalidMessage),
            _ if !files.is_empty() => Err(Error::InvalidMessage),
            _ => Ok(()),
        }
    }

    /// Validates the request size and decodes a fixed-size body of type `T` from `buf`.
    ///
    /// Both the header size and `size` must equal `size_of::<T>()`. The decoded value is
    /// returned only if `T::is_valid` accepts it; otherwise `Error::InvalidMessage` is
    /// returned.
    fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        size: usize,
        buf: &[u8],
    ) -> Result<T> {
        self.check_request_size(hdr, size, mem::size_of::<T>())?;
        T::read_from_prefix(buf)
            .filter(T::is_valid)
            .ok_or(Error::InvalidMessage)
    }

    /// Recomputes `reply_ack_enabled`.
    ///
    /// The flag is true only when all of the following hold:
    /// - the `VHOST_USER_F_PROTOCOL_FEATURES` bit is set in `virtio_features`;
    /// - `REPLY_ACK` is in `protocol_features`;
    /// - `REPLY_ACK` is in `acked_protocol_features`.
    fn update_reply_ack_flag(&mut self) {
        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
        self.reply_ack_enabled = (self.virtio_features & (1 << VHOST_USER_F_PROTOCOL_FEATURES))
            != 0
            && self.protocol_features.contains(pflag)
            && (self.acked_protocol_features & pflag.bits()) != 0;
    }
}

impl<S: Backend> AsRawDescriptor for BackendServer<S> {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        // TODO(b/221882601): figure out if this used for polling.
        self.connection.as_raw_descriptor()
    }
}

#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;

    #[test]
    fn test_backend_server_new() {
        let (p1, _p2) = Connection::pair().unwrap();
        let backend = TestBackend::new();
        let handler = BackendServer::new(p1, backend);

        assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR);
    }
}