1 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
2 
3 use std::{
4     io::{ErrorKind, Read, Write},
5     num::Wrapping,
6     os::unix::prelude::{AsRawFd, RawFd},
7 };
8 
9 use log::{error, info};
10 use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
11 use vm_memory::{bitmap::BitmapSlice, Bytes, VolatileSlice};
12 
13 use crate::{
14     rxops::*,
15     rxqueue::*,
16     txbuf::*,
17     vhu_vsock::{
18         Error, Result, VSOCK_FLAGS_SHUTDOWN_RCV, VSOCK_FLAGS_SHUTDOWN_SEND,
19         VSOCK_OP_CREDIT_REQUEST, VSOCK_OP_CREDIT_UPDATE, VSOCK_OP_REQUEST, VSOCK_OP_RESPONSE,
20         VSOCK_OP_RST, VSOCK_OP_RW, VSOCK_OP_SHUTDOWN, VSOCK_TYPE_STREAM,
21     },
22     vhu_vsock_thread::VhostUserVsockThread,
23 };
24 
25 #[derive(Debug)]
26 pub(crate) struct VsockConnection<S> {
27     /// Host-side stream corresponding to this vsock connection.
28     pub stream: S,
29     /// Specifies if the stream is connected to a listener on the host.
30     pub connect: bool,
31     /// Port at which a guest application is listening to.
32     pub peer_port: u32,
33     /// Queue holding pending rx operations per connection.
34     pub rx_queue: RxQueue,
35     /// CID of the host.
36     local_cid: u64,
37     /// Port on the host at which a host-side application listens to.
38     pub local_port: u32,
39     /// CID of the guest.
40     pub guest_cid: u64,
41     /// Total number of bytes written to stream from tx buffer.
42     pub fwd_cnt: Wrapping<u32>,
43     /// Total number of bytes previously forwarded to stream.
44     last_fwd_cnt: Wrapping<u32>,
45     /// Size of buffer the guest has allocated for this connection.
46     peer_buf_alloc: u32,
47     /// Number of bytes the peer has forwarded to a connection.
48     peer_fwd_cnt: Wrapping<u32>,
49     /// The total number of bytes sent to the guest vsock driver.
50     rx_cnt: Wrapping<u32>,
51     /// epoll fd to which this connection's stream has to be added.
52     pub epoll_fd: RawFd,
53     /// Local tx buffer.
54     pub tx_buf: LocalTxBuf,
55     /// Local tx buffer size
56     tx_buffer_size: u32,
57 }
58 
59 impl<S: AsRawFd + Read + Write> VsockConnection<S> {
60     /// Create a new vsock connection object for locally i.e host-side
61     /// inititated connections.
new_local_init( stream: S, local_cid: u64, local_port: u32, guest_cid: u64, guest_port: u32, epoll_fd: RawFd, tx_buffer_size: u32, ) -> Self62     pub fn new_local_init(
63         stream: S,
64         local_cid: u64,
65         local_port: u32,
66         guest_cid: u64,
67         guest_port: u32,
68         epoll_fd: RawFd,
69         tx_buffer_size: u32,
70     ) -> Self {
71         Self {
72             stream,
73             connect: false,
74             peer_port: guest_port,
75             rx_queue: RxQueue::new(),
76             local_cid,
77             local_port,
78             guest_cid,
79             fwd_cnt: Wrapping(0),
80             last_fwd_cnt: Wrapping(0),
81             peer_buf_alloc: 0,
82             peer_fwd_cnt: Wrapping(0),
83             rx_cnt: Wrapping(0),
84             epoll_fd,
85             tx_buf: LocalTxBuf::new(tx_buffer_size),
86             tx_buffer_size,
87         }
88     }
89 
90     /// Create a new vsock connection object for connections initiated by
91     /// an application running in the guest.
92     #[allow(clippy::too_many_arguments)]
new_peer_init( stream: S, local_cid: u64, local_port: u32, guest_cid: u64, guest_port: u32, epoll_fd: RawFd, peer_buf_alloc: u32, tx_buffer_size: u32, ) -> Self93     pub fn new_peer_init(
94         stream: S,
95         local_cid: u64,
96         local_port: u32,
97         guest_cid: u64,
98         guest_port: u32,
99         epoll_fd: RawFd,
100         peer_buf_alloc: u32,
101         tx_buffer_size: u32,
102     ) -> Self {
103         let mut rx_queue = RxQueue::new();
104         rx_queue.enqueue(RxOps::Response);
105         Self {
106             stream,
107             connect: false,
108             peer_port: guest_port,
109             rx_queue,
110             local_cid,
111             local_port,
112             guest_cid,
113             fwd_cnt: Wrapping(0),
114             last_fwd_cnt: Wrapping(0),
115             peer_buf_alloc,
116             peer_fwd_cnt: Wrapping(0),
117             rx_cnt: Wrapping(0),
118             epoll_fd,
119             tx_buf: LocalTxBuf::new(tx_buffer_size),
120             tx_buffer_size,
121         }
122     }
123 
124     /// Set the peer port to the guest side application's port.
set_peer_port(&mut self, peer_port: u32)125     pub fn set_peer_port(&mut self, peer_port: u32) {
126         self.peer_port = peer_port;
127     }
128 
129     /// Process a vsock packet that is meant for this connection.
130     /// Forward data to the host-side application if the vsock packet
131     /// contains a RW operation.
recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()>132     pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
133         // Initialize all fields in the packet header
134         self.init_pkt(pkt);
135 
136         match self.rx_queue.dequeue() {
137             Some(RxOps::Request) => {
138                 // Send a connection request to the guest-side application
139                 pkt.set_op(VSOCK_OP_REQUEST);
140                 Ok(())
141             }
142             Some(RxOps::Rw) => {
143                 if !self.connect {
144                     // There is no host-side application listening for this
145                     // packet, hence send back an RST.
146                     pkt.set_op(VSOCK_OP_RST);
147                     return Ok(());
148                 }
149 
150                 // Check if peer has space for receiving data
151                 if self.need_credit_update_from_peer() {
152                     self.last_fwd_cnt = self.fwd_cnt;
153                     pkt.set_op(VSOCK_OP_CREDIT_REQUEST);
154                     return Ok(());
155                 }
156                 let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?;
157 
158                 // Perform a credit check to find the maximum read size. The read
159                 // data must fit inside a packet buffer and be within peer's
160                 // available buffer space
161                 let max_read_len = std::cmp::min(buf.len(), self.peer_avail_credit());
162 
163                 // Read data from the stream directly into the buffer
164                 if let Ok(read_cnt) = buf.read_from(0, &mut self.stream, max_read_len) {
165                     if read_cnt == 0 {
166                         // If no data was read then the stream was closed down unexpectedly.
167                         // Send a shutdown packet to the guest-side application.
168                         pkt.set_op(VSOCK_OP_SHUTDOWN)
169                             .set_flag(VSOCK_FLAGS_SHUTDOWN_RCV)
170                             .set_flag(VSOCK_FLAGS_SHUTDOWN_SEND);
171                     } else {
172                         // If data was read, then set the length field in the packet header
173                         // to the amount of data that was read.
174                         pkt.set_op(VSOCK_OP_RW).set_len(read_cnt as u32);
175 
176                         // Re-register the stream file descriptor for read and write events
177                         if VhostUserVsockThread::epoll_modify(
178                             self.epoll_fd,
179                             self.stream.as_raw_fd(),
180                             epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
181                         )
182                         .is_err()
183                         {
184                             if let Err(e) = VhostUserVsockThread::epoll_register(
185                                 self.epoll_fd,
186                                 self.stream.as_raw_fd(),
187                                 epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
188                             ) {
189                                 // TODO: let's move this logic out of this func, and handle it properly
190                                 error!("epoll_register failed: {:?}, but proceed further.", e);
191                             }
192                         };
193                     }
194 
195                     // Update the rx_cnt with the amount of data in the vsock packet.
196                     self.rx_cnt += Wrapping(pkt.len());
197                     self.last_fwd_cnt = self.fwd_cnt;
198                 }
199                 Ok(())
200             }
201             Some(RxOps::Response) => {
202                 // A response has been received to a newly initiated host-side connection
203                 self.connect = true;
204                 pkt.set_op(VSOCK_OP_RESPONSE);
205                 Ok(())
206             }
207             Some(RxOps::CreditUpdate) => {
208                 // Request credit update from the guest.
209                 if !self.rx_queue.pending_rx() {
210                     // Waste an rx buffer if no rx is pending
211                     pkt.set_op(VSOCK_OP_CREDIT_UPDATE);
212                     self.last_fwd_cnt = self.fwd_cnt;
213                 }
214                 Ok(())
215             }
216             _ => Err(Error::NoRequestRx),
217         }
218     }
219 
220     /// Deliver a guest generated packet to this connection.
221     ///
222     /// Returns:
223     /// - always `Ok(())` to indicate that the packet has been consumed
send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()>224     pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> {
225         // Update peer credit information
226         self.peer_buf_alloc = pkt.buf_alloc();
227         self.peer_fwd_cnt = Wrapping(pkt.fwd_cnt());
228 
229         match pkt.op() {
230             VSOCK_OP_RESPONSE => {
231                 // Confirmation for a host initiated connection
232                 // TODO: Handle stream write error in a better manner
233                 let response = format!("OK {}\n", self.peer_port);
234                 self.stream.write_all(response.as_bytes()).unwrap();
235                 self.connect = true;
236             }
237             VSOCK_OP_RW => {
238                 // Data has to be written to the host-side stream
239                 match pkt.data_slice() {
240                     None => {
241                         info!(
242                             "Dropping empty packet from guest (lp={}, pp={})",
243                             self.local_port, self.peer_port
244                         );
245                         return Ok(());
246                     }
247                     Some(buf) => {
248                         if let Err(err) = self.send_bytes(buf) {
249                             // TODO: Terminate this connection
250                             dbg!("err:{:?}", err);
251                             return Ok(());
252                         }
253                     }
254                 }
255             }
256             VSOCK_OP_CREDIT_UPDATE => {
257                 // Already updated the credit
258 
259                 // Re-register the stream file descriptor for read and write events
260                 if VhostUserVsockThread::epoll_modify(
261                     self.epoll_fd,
262                     self.stream.as_raw_fd(),
263                     epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
264                 )
265                 .is_err()
266                 {
267                     if let Err(e) = VhostUserVsockThread::epoll_register(
268                         self.epoll_fd,
269                         self.stream.as_raw_fd(),
270                         epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
271                     ) {
272                         // TODO: let's move this logic out of this func, and handle it properly
273                         error!("epoll_register failed: {:?}, but proceed further.", e);
274                     }
275                 };
276             }
277             VSOCK_OP_CREDIT_REQUEST => {
278                 // Send back this connection's credit information
279                 self.rx_queue.enqueue(RxOps::CreditUpdate);
280             }
281             VSOCK_OP_SHUTDOWN => {
282                 // Shutdown this connection
283                 let recv_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_RCV != 0;
284                 let send_off = pkt.flags() & VSOCK_FLAGS_SHUTDOWN_SEND != 0;
285 
286                 if recv_off && send_off && self.tx_buf.is_empty() {
287                     self.rx_queue.enqueue(RxOps::Reset);
288                 }
289             }
290             _ => {}
291         }
292 
293         Ok(())
294     }
295 
296     /// Write data to the host-side stream.
297     ///
298     /// Returns:
299     /// - Ok(cnt) where cnt is the number of bytes written to the stream
300     /// - Err(Error::UnixWrite) if there was an error writing to the stream
send_bytes<B: BitmapSlice>(&mut self, buf: &VolatileSlice<B>) -> Result<()>301     fn send_bytes<B: BitmapSlice>(&mut self, buf: &VolatileSlice<B>) -> Result<()> {
302         if !self.tx_buf.is_empty() {
303             // Data is already present in the buffer and the backend
304             // is waiting for a EPOLLOUT event to flush it
305             return self.tx_buf.push(buf);
306         }
307 
308         // Write data to the stream
309         let written_count = match buf.write_to(0, &mut self.stream, buf.len()) {
310             Ok(cnt) => cnt,
311             Err(vm_memory::VolatileMemoryError::IOError(e)) => {
312                 if e.kind() == ErrorKind::WouldBlock {
313                     0
314                 } else {
315                     dbg!("send_bytes error: {:?}", e);
316                     return Err(Error::UnixWrite);
317                 }
318             }
319             Err(e) => {
320                 dbg!("send_bytes error: {:?}", e);
321                 return Err(Error::UnixWrite);
322             }
323         };
324 
325         if written_count > 0 {
326             // Increment forwarded count by number of bytes written to the stream
327             self.fwd_cnt += Wrapping(written_count as u32);
328 
329             // At what point in available credits should we send a credit update.
330             // This is set to 1/4th of the tx buffer size. If we keep it too low,
331             // we will end up sending too many credit updates. If we keep it too
332             // high, we will end up sending too few credit updates and cause stalls.
333             // Stalls are more bad than too many credit updates.
334             let free_space = self
335                 .tx_buffer_size
336                 .wrapping_sub((self.fwd_cnt - self.last_fwd_cnt).0);
337             if free_space < self.tx_buffer_size / 4 {
338                 self.rx_queue.enqueue(RxOps::CreditUpdate);
339             }
340         }
341 
342         if written_count != buf.len() {
343             return self.tx_buf.push(&buf.offset(written_count).unwrap());
344         }
345 
346         Ok(())
347     }
348 
349     /// Initialize all header fields in the vsock packet.
init_pkt<'a, 'b, B: BitmapSlice>( &self, pkt: &'a mut VsockPacket<'b, B>, ) -> &'a mut VsockPacket<'b, B>350     fn init_pkt<'a, 'b, B: BitmapSlice>(
351         &self,
352         pkt: &'a mut VsockPacket<'b, B>,
353     ) -> &'a mut VsockPacket<'b, B> {
354         // Zero out the packet header
355         pkt.set_header_from_raw(&[0u8; PKT_HEADER_SIZE]).unwrap();
356 
357         pkt.set_src_cid(self.local_cid)
358             .set_dst_cid(self.guest_cid)
359             .set_src_port(self.local_port)
360             .set_dst_port(self.peer_port)
361             .set_type(VSOCK_TYPE_STREAM)
362             .set_buf_alloc(self.tx_buffer_size)
363             .set_fwd_cnt(self.fwd_cnt.0)
364     }
365 
366     /// Get max number of bytes we can send to peer without overflowing
367     /// the peer's buffer.
peer_avail_credit(&self) -> usize368     fn peer_avail_credit(&self) -> usize {
369         (Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize
370     }
371 
372     /// Check if we need a credit update from the peer before sending
373     /// more data to it.
need_credit_update_from_peer(&self) -> bool374     fn need_credit_update_from_peer(&self) -> bool {
375         self.peer_avail_credit() == 0
376     }
377 }
378 
#[cfg(test)]
mod tests {
    use byteorder::{ByteOrder, LittleEndian};

    use super::*;
    use crate::vhu_vsock::{VSOCK_HOST_CID, VSOCK_OP_RW, VSOCK_TYPE_STREAM};
    use std::io::Result as IoResult;
    use std::ops::Deref;
    use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_NEXT, VRING_DESC_F_WRITE};
    use virtio_queue::{mock::MockSplitQueue, Descriptor, DescriptorChain, Queue, QueueOwnedT};
    use vm_memory::{
        Address, Bytes, GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryLoadGuard,
        GuestMemoryMmap,
    };

    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;

    /// Parameters of the header descriptor of a test packet.
    struct HeadParams {
        /// Length of the header descriptor in bytes.
        head_len: usize,
        /// Value written into the header's data-length field.
        data_len: u32,
    }

    impl HeadParams {
        fn new(head_len: usize, data_len: u32) -> Self {
            Self { head_len, data_len }
        }

        /// Build a zeroed header of `head_len` bytes. The data-length field is
        /// only filled in when the header has the full vsock header size.
        fn construct_head(&self) -> Vec<u8> {
            let mut header = vec![0_u8; self.head_len];
            if self.head_len == PKT_HEADER_SIZE {
                // Offset into the header for data length
                const HDROFF_LEN: usize = 24;
                LittleEndian::write_u32(&mut header[HDROFF_LEN..], self.data_len);
            }
            header
        }
    }

    /// Build a mock queue holding one descriptor chain: a header descriptor
    /// followed by `data_chain_len` data descriptors of `head_data_len` bytes.
    /// All descriptors are device-writable when `write_only` is set.
    /// Returns the guest memory and the chain popped from the queue.
    fn prepare_desc_chain_vsock(
        write_only: bool,
        head_params: &HeadParams,
        data_chain_len: u16,
        head_data_len: u32,
    ) -> (
        GuestMemoryAtomic<GuestMemoryMmap>,
        DescriptorChain<GuestMemoryLoadGuard<GuestMemoryMmap>>,
    ) {
        let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x1000)]).unwrap();
        let virt_queue = MockSplitQueue::new(&mem, 16);
        let mut next_addr = virt_queue.desc_table().total_size() + 0x100;
        let mut flags = 0;

        if write_only {
            flags |= VRING_DESC_F_WRITE;
        }

        let mut head_flags = if data_chain_len > 0 {
            flags | VRING_DESC_F_NEXT
        } else {
            flags
        };

        // vsock packet header
        let header = head_params.construct_head();
        let head_desc =
            Descriptor::new(next_addr, head_params.head_len as u32, head_flags as u16, 1);
        mem.write(&header, head_desc.addr()).unwrap();
        assert!(virt_queue.desc_table().store(0, head_desc).is_ok());
        next_addr += head_params.head_len as u64;

        // Put the descriptor index 0 in the first available ring position.
        mem.write_obj(0u16, virt_queue.avail_addr().unchecked_add(4))
            .unwrap();

        // Set `avail_idx` to 1.
        mem.write_obj(1u16, virt_queue.avail_addr().unchecked_add(2))
            .unwrap();

        // chain len excludes the head
        for i in 0..(data_chain_len) {
            // last descr in chain
            if i == data_chain_len - 1 {
                head_flags &= !VRING_DESC_F_NEXT;
            }
            // vsock data
            let data = vec![0_u8; head_data_len as usize];
            let data_desc = Descriptor::new(next_addr, data.len() as u32, head_flags as u16, i + 2);
            mem.write(&data, data_desc.addr()).unwrap();
            assert!(virt_queue.desc_table().store(i + 1, data_desc).is_ok());
            next_addr += head_data_len as u64;
        }

        // Create descriptor chain from pre-filled memory
        (
            GuestMemoryAtomic::new(mem.clone()),
            virt_queue
                .create_queue::<Queue>()
                .unwrap()
                .iter(GuestMemoryAtomic::new(mem.clone()).memory())
                .unwrap()
                .next()
                .unwrap(),
        )
    }

    /// In-memory stand-in for the host-side stream. `write` replaces the
    /// stored data with the last buffer written; `read` returns a copy of it
    /// without consuming it.
    struct VsockDummySocket {
        data: Vec<u8>,
    }

    impl VsockDummySocket {
        fn new() -> Self {
            Self { data: Vec::new() }
        }
    }

    impl Write for VsockDummySocket {
        fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
            self.data.clear();
            self.data.extend_from_slice(buf);

            Ok(buf.len())
        }
        fn flush(&mut self) -> IoResult<()> {
            Ok(())
        }
    }

    impl Read for VsockDummySocket {
        fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
            // Copy as much of the stored data as fits in the caller's buffer.
            let len = buf.len().min(self.data.len());
            buf[..len].copy_from_slice(&self.data[..len]);
            Ok(len)
        }
    }

    impl AsRawFd for VsockDummySocket {
        fn as_raw_fd(&self) -> RawFd {
            -1
        }
    }

    #[test]
    fn test_vsock_conn_init() {
        // new locally initiated connection
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        assert!(!conn_local.connect);
        assert_eq!(conn_local.peer_port, 5001);
        assert_eq!(conn_local.rx_queue, RxQueue::new());
        assert_eq!(conn_local.local_cid, VSOCK_HOST_CID);
        assert_eq!(conn_local.local_port, 5000);
        assert_eq!(conn_local.guest_cid, 3);

        // set peer port
        conn_local.set_peer_port(5002);
        assert_eq!(conn_local.peer_port, 5002);

        // New connection initiated by the peer/guest
        let dummy_file = VsockDummySocket::new();
        let mut conn_peer = VsockConnection::new_peer_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            65536,
            CONN_TX_BUF_SIZE,
        );

        assert!(!conn_peer.connect);
        assert_eq!(conn_peer.peer_port, 5001);
        assert_eq!(conn_peer.rx_queue.dequeue().unwrap(), RxOps::Response);
        assert!(!conn_peer.rx_queue.pending_rx());
        assert_eq!(conn_peer.local_cid, VSOCK_HOST_CID);
        assert_eq!(conn_peer.local_port, 5000);
        assert_eq!(conn_peer.guest_cid, 3);
        assert_eq!(conn_peer.peer_buf_alloc, 65536);
    }

    #[test]
    fn test_vsock_conn_credit() {
        // new locally initiated connection
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        assert_eq!(conn_local.peer_avail_credit(), 0);
        assert!(conn_local.need_credit_update_from_peer());

        conn_local.peer_buf_alloc = 65536;
        assert_eq!(conn_local.peer_avail_credit(), 65536);
        assert!(!conn_local.need_credit_update_from_peer());

        conn_local.rx_cnt = Wrapping(32768);
        assert_eq!(conn_local.peer_avail_credit(), 32768);
        assert!(!conn_local.need_credit_update_from_peer());

        conn_local.rx_cnt = Wrapping(65536);
        assert_eq!(conn_local.peer_avail_credit(), 0);
        assert!(conn_local.need_credit_update_from_peer());
    }

    #[test]
    fn test_vsock_conn_init_pkt() {
        // parameters for packet head construction
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 10);

        // new locally initiated connection
        let dummy_file = VsockDummySocket::new();
        let conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        // write only descriptor chain
        let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 2, 10);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_rx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        // initialize a vsock packet for the guest
        conn_local.init_pkt(&mut pkt);

        assert_eq!(pkt.src_cid(), VSOCK_HOST_CID);
        assert_eq!(pkt.dst_cid(), 3);
        assert_eq!(pkt.src_port(), 5000);
        assert_eq!(pkt.dst_port(), 5001);
        assert_eq!(pkt.type_(), VSOCK_TYPE_STREAM);
        assert_eq!(pkt.buf_alloc(), CONN_TX_BUF_SIZE);
        assert_eq!(pkt.fwd_cnt(), 0);
    }

    #[test]
    fn test_vsock_conn_recv_pkt() {
        // parameters for packet head construction
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 5);

        // new locally initiated connection
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        // write only descriptor chain
        let (mem, mut descr_chain) = prepare_desc_chain_vsock(true, &head_params, 1, 5);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_rx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        // VSOCK_OP_REQUEST: new local conn request
        conn_local.rx_queue.enqueue(RxOps::Request);
        let op_req = conn_local.recv_pkt(&mut pkt);
        assert!(op_req.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_REQUEST);

        // VSOCK_OP_RST: reset if connection not established
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_rst = conn_local.recv_pkt(&mut pkt);
        assert!(op_rst.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_RST);

        // VSOCK_OP_CREDIT_REQUEST: no credit from the peer/guest yet, so ask
        // it for a credit update
        conn_local.connect = true;
        conn_local.rx_queue.enqueue(RxOps::Rw);
        conn_local.fwd_cnt = Wrapping(1024);
        let op_credit_update = conn_local.recv_pkt(&mut pkt);
        assert!(op_credit_update.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_CREDIT_REQUEST);
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));

        // VSOCK_OP_SHUTDOWN: zero data read from stream/file
        conn_local.peer_buf_alloc = 65536;
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_zero_read_shutdown = conn_local.recv_pkt(&mut pkt);
        assert!(op_zero_read_shutdown.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(conn_local.rx_cnt, Wrapping(0));
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));
        assert_eq!(pkt.op(), VSOCK_OP_SHUTDOWN);
        assert_eq!(
            pkt.flags(),
            VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND
        );

        // VSOCK_OP_RW: finite data read from stream/file
        let payload = b"hello";
        conn_local.stream.write_all(payload).unwrap();
        conn_local.rx_queue.enqueue(RxOps::Rw);
        let op_data_read = conn_local.recv_pkt(&mut pkt);
        assert!(op_data_read.is_ok());
        assert_eq!(pkt.op(), VSOCK_OP_RW);
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(conn_local.rx_cnt, Wrapping(payload.len() as u32));
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));
        assert_eq!(pkt.len(), 5);
        let buf = &mut [0u8; 5];
        assert!(pkt.data_slice().unwrap().read_slice(buf, 0).is_ok());
        assert_eq!(buf, b"hello");

        // VSOCK_OP_RESPONSE: response to a connection request
        conn_local.rx_queue.enqueue(RxOps::Response);
        let op_response = conn_local.recv_pkt(&mut pkt);
        assert!(op_response.is_ok());
        assert!(!conn_local.rx_queue.pending_rx());
        assert_eq!(pkt.op(), VSOCK_OP_RESPONSE);
        assert!(conn_local.connect);

        // VSOCK_OP_CREDIT_UPDATE: guest needs credit update
        conn_local.rx_queue.enqueue(RxOps::CreditUpdate);
        let op_credit_update = conn_local.recv_pkt(&mut pkt);
        assert!(!conn_local.rx_queue.pending_rx());
        assert!(op_credit_update.is_ok());
        assert_eq!(pkt.op(), VSOCK_OP_CREDIT_UPDATE);
        assert_eq!(conn_local.last_fwd_cnt, Wrapping(1024));

        // non-existent request
        let op_error = conn_local.recv_pkt(&mut pkt);
        assert!(op_error.is_err());
    }

    #[test]
    fn test_vsock_conn_send_pkt() {
        // parameters for packet head construction
        let head_params = HeadParams::new(PKT_HEADER_SIZE, 5);

        // new locally initiated connection
        let dummy_file = VsockDummySocket::new();
        let mut conn_local = VsockConnection::new_local_init(
            dummy_file,
            VSOCK_HOST_CID,
            5000,
            3,
            5001,
            -1,
            CONN_TX_BUF_SIZE,
        );

        // read only (device-readable) descriptor chain, as used for guest tx
        let (mem, mut descr_chain) = prepare_desc_chain_vsock(false, &head_params, 1, 5);
        let mem = mem.memory();
        let mut pkt =
            VsockPacket::from_tx_virtq_chain(mem.deref(), &mut descr_chain, CONN_TX_BUF_SIZE)
                .unwrap();

        // peer credit information
        pkt.set_buf_alloc(65536).set_fwd_cnt(1024);

        // check if peer credit information is updated correctly
        let credit_check = conn_local.send_pkt(&pkt);
        assert!(credit_check.is_ok());
        assert_eq!(conn_local.peer_buf_alloc, 65536);
        assert_eq!(conn_local.peer_fwd_cnt, Wrapping(1024));

        // VSOCK_OP_RESPONSE
        pkt.set_op(VSOCK_OP_RESPONSE);
        let peer_response = conn_local.send_pkt(&pkt);
        assert!(peer_response.is_ok());
        assert!(conn_local.connect);
        let mut resp_buf = vec![0; 8];
        conn_local.stream.read_exact(&mut resp_buf).unwrap();
        assert_eq!(resp_buf, b"OK 5001\n");

        // VSOCK_OP_RW
        pkt.set_op(VSOCK_OP_RW);
        let buf = b"hello";
        assert!(pkt.data_slice().unwrap().write_slice(buf, 0).is_ok());
        let rw_response = conn_local.send_pkt(&pkt);
        assert!(rw_response.is_ok());
        let mut resp_buf = vec![0; 5];
        conn_local.stream.read_exact(&mut resp_buf).unwrap();
        assert_eq!(resp_buf, b"hello");

        // VSOCK_OP_CREDIT_REQUEST
        pkt.set_op(VSOCK_OP_CREDIT_REQUEST);
        let credit_response = conn_local.send_pkt(&pkt);
        assert!(credit_response.is_ok());
        assert_eq!(conn_local.rx_queue.peek().unwrap(), RxOps::CreditUpdate);

        // VSOCK_OP_SHUTDOWN
        pkt.set_op(VSOCK_OP_SHUTDOWN);
        pkt.set_flags(VSOCK_FLAGS_SHUTDOWN_RCV | VSOCK_FLAGS_SHUTDOWN_SEND);
        let shutdown_response = conn_local.send_pkt(&pkt);
        assert!(shutdown_response.is_ok());
        assert!(conn_local.rx_queue.contains(RxOps::Reset.bitmask()));
    }
}
797