/*
 * Copyright (C) 2019 - 2020 Intel Corporation
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include <stdlib.h>
#include <usfstl/list.h>
#include <usfstl/loop.h>
#include <usfstl/uds.h>
#include <sys/socket.h>
#include <sys/mman.h>
#include <sys/un.h>
#include <errno.h>
#include <usfstl/vhost.h>
#include <linux/virtio_ring.h>
#include <linux/virtio_config.h>
#include <endian.h>

/* copied from uapi */
#define VIRTIO_F_VERSION_1		32

#define MAX_REGIONS 8
#define SG_STACK_PREALLOC 5

struct usfstl_vhost_user_dev_int {
	struct usfstl_list fds;
	struct usfstl_job irq_job;

	struct usfstl_loop_entry entry;

	struct usfstl_vhost_user_dev ext;

	unsigned int n_regions;
	struct vhost_user_region regions[MAX_REGIONS];
	int region_fds[MAX_REGIONS];
	void *region_vaddr[MAX_REGIONS];

	int req_fd;

	struct {
		struct usfstl_loop_entry entry;
		bool enabled;
		bool triggered;
		struct vring virtq;
		int call_fd;
		uint16_t last_avail_idx;
	} virtqs[];
};

#define CONV(bits)							\
static inline uint##bits##_t __attribute__((used))			\
cpu_to_virtio##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return htole##bits(v);					\
	return v;							\
}									\
static inline uint##bits##_t __attribute__((used))			\
virtio_to_cpu##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return le##bits##toh(v);				\
	return v;							\
}

CONV(16)
CONV(32)
CONV(64)

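/*
 * Fetch the next available descriptor chain from the given virtqueue and
 * translate it into a usfstl_vhost_user_buf with in/out iovec arrays.
 * Uses the caller-provided (stack-preallocated) buffer when it is large
 * enough, otherwise heap-allocates one; returns NULL when the queue is
 * empty (or allocation fails).
 */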
static struct usfstl_vhost_user_buf *
usfstl_vhost_user_get_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
				unsigned int virtq_idx,
				struct usfstl_vhost_user_buf *fixed)
{
	struct usfstl_vhost_user_buf *buf = fixed;
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	uint16_t avail_idx = virtio_to_cpu16(dev, virtq->avail->idx);
	uint16_t idx, desc_idx;
	struct vring_desc *desc;
	unsigned int n_in = 0, n_out = 0;
	bool more;

	if (avail_idx == dev->virtqs[virtq_idx].last_avail_idx)
		return NULL;

	/* ensure we read the descriptor after checking the index */
	__sync_synchronize();

	idx = dev->virtqs[virtq_idx].last_avail_idx++;
	idx %= virtq->num;
	desc_idx = virtio_to_cpu16(dev, virtq->avail->ring[idx]);
	USFSTL_ASSERT(desc_idx < virtq->num);

	desc = &virtq->desc[desc_idx];
	do {
		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE)
			n_in++;
		else
			n_out++;
		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	if (n_in > fixed->n_in_sg || n_out > fixed->n_out_sg) {
		size_t sz = sizeof(*buf);
		struct iovec *vec;

		sz += (n_in + n_out) * sizeof(*vec);

		buf = calloc(1, sz);
		if (!buf)
			return NULL;

		vec = (void *)(buf + 1);
		buf->in_sg = vec;
		buf->out_sg = vec + n_in;
		buf->allocated = true;
	}

	buf->n_in_sg = 0;
	buf->n_out_sg = 0;
	buf->idx = desc_idx;

	desc = &virtq->desc[desc_idx];
	do {
		struct iovec *vec;
		uint64_t addr;

		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE) {
			vec = &buf->in_sg[buf->n_in_sg];
			buf->n_in_sg++;
		} else {
			vec = &buf->out_sg[buf->n_out_sg];
			buf->n_out_sg++;
		}

		addr = virtio_to_cpu64(dev, desc->addr);
		vec->iov_base = usfstl_vhost_phys_to_va(&dev->ext, addr);
		vec->iov_len = virtio_to_cpu32(dev, desc->len);

		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	return buf;
}

static void usfstl_vhost_user_free_buf(struct usfstl_vhost_user_buf *buf)
{
	if (buf->allocated)
		free(buf);
}

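/*
 * Loop-entry handler used while waiting for a reply on the slave/request
 * channel: it just unregisters itself and sets fd to -1 so the waiter in
 * usfstl_vhost_user_send_msg() can notice that the fd became readable.
 */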
static void usfstl_vhost_user_readable_handler(struct usfstl_loop_entry *entry)
{
	usfstl_loop_unregister(entry);
	entry->fd = -1;
}

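/*
 * Receive one vhost-user message on @fd into the caller's msghdr: first
 * read exactly the fixed header (including any SCM_RIGHTS fds), then read
 * hdr->size payload bytes. Returns 0 on success or a negative errno.
 */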
static int usfstl_vhost_user_read_msg(int fd, struct msghdr *msghdr)
{
	struct iovec msg_iov;
	struct msghdr hdr2 = {
		.msg_iov = &msg_iov,
		.msg_iovlen = 1,
		.msg_control = msghdr->msg_control,
		.msg_controllen = msghdr->msg_controllen,
	};
	struct vhost_user_msg_hdr *hdr;
	size_t i;
	size_t maxlen = 0;
	ssize_t len;
	ssize_t prev_datalen;
	size_t prev_iovlen;

	USFSTL_ASSERT(msghdr->msg_iovlen >= 1);
	USFSTL_ASSERT(msghdr->msg_iov[0].iov_len >= sizeof(*hdr));

	hdr = msghdr->msg_iov[0].iov_base;
	msg_iov.iov_base = hdr;
	msg_iov.iov_len = sizeof(*hdr);

	len = recvmsg(fd, &hdr2, 0);
	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	for (i = 0; i < msghdr->msg_iovlen; i++)
		maxlen += msghdr->msg_iov[i].iov_len;
	maxlen -= sizeof(*hdr);

	USFSTL_ASSERT_EQ((int)len, (int)sizeof(*hdr), "%d");
	USFSTL_ASSERT(hdr->size <= maxlen);

	if (!hdr->size)
		return 0;

	prev_iovlen = msghdr->msg_iovlen;
	msghdr->msg_iovlen = 1;

	msghdr->msg_control = NULL;
	msghdr->msg_controllen = 0;
	msghdr->msg_iov[0].iov_base += sizeof(*hdr);
	prev_datalen = msghdr->msg_iov[0].iov_len;
	msghdr->msg_iov[0].iov_len = hdr->size;
	len = recvmsg(fd, msghdr, 0);

	/* restore just in case the user needs it */
	msghdr->msg_iov[0].iov_base -= sizeof(*hdr);
	msghdr->msg_iov[0].iov_len = prev_datalen;
	msghdr->msg_control = hdr2.msg_control;
	msghdr->msg_controllen = hdr2.msg_controllen;

	msghdr->msg_iovlen = prev_iovlen;

	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	USFSTL_ASSERT_EQ(hdr->size, (uint32_t)len, "%u");

	return 0;
}

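/*
 * Send a backend ("slave") request to the driver over the request fd.
 * When VHOST_USER_PROTOCOL_F_REPLY_ACK was negotiated, also wait for the
 * reply, handling other loop events (e.g. simulation time messages) while
 * waiting.
 */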
static void usfstl_vhost_user_send_msg(struct usfstl_vhost_user_dev_int *dev,
				       struct vhost_user_msg *msg)
{
	size_t msgsz = sizeof(msg->hdr) + msg->hdr.size;
	bool ack = dev->ext.protocol_features &
		   (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
	ssize_t written;

	if (ack)
		msg->hdr.flags |= VHOST_USER_MSG_FLAGS_NEED_REPLY;

	written = write(dev->req_fd, msg, msgsz);
	USFSTL_ASSERT_EQ(written, (ssize_t)msgsz, "%zd");

	if (ack) {
		struct usfstl_loop_entry entry = {
			.fd = dev->req_fd,
			.priority = 0x7fffffff, // max
			.handler = usfstl_vhost_user_readable_handler,
		};
		struct iovec msg_iov = {
			.iov_base = msg,
			.iov_len = sizeof(*msg),
		};
		struct msghdr msghdr = {
			.msg_iovlen = 1,
			.msg_iov = &msg_iov,
		};

		/*
		 * Wait for the fd to be readable - we may have to
		 * handle other simulation (time) messages while
		 * waiting ...
		 */
		usfstl_loop_register(&entry);
		while (entry.fd != -1)
			usfstl_loop_wait_and_handle();
		USFSTL_ASSERT_EQ(usfstl_vhost_user_read_msg(dev->req_fd,
							    &msghdr),
				 0, "%d");
	}
}

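/*
 * Return a completed buffer to the driver: add it to the used ring and
 * then notify, either by writing to the queue's call eventfd or - if no
 * call fd is set and the protocol features allow it - via an in-band
 * VHOST_USER_SLAVE_VRING_CALL message.
 */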
static void usfstl_vhost_user_send_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
					     struct usfstl_vhost_user_buf *buf,
					     int virtq_idx)
{
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	unsigned int idx, widx;
	int call_fd = dev->virtqs[virtq_idx].call_fd;
	ssize_t written;
	uint64_t e = 1;

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_to(dev->ext.server->ctrl);

	idx = virtio_to_cpu16(dev, virtq->used->idx);
	widx = idx + 1;

	idx %= virtq->num;
	virtq->used->ring[idx].id = cpu_to_virtio32(dev, buf->idx);
	virtq->used->ring[idx].len = cpu_to_virtio32(dev, buf->written);

	/* write buffers / used table before flush */
	__sync_synchronize();

	virtq->used->idx = cpu_to_virtio16(dev, widx);

	if (call_fd < 0 &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
		struct vhost_user_msg msg = {
			.hdr.request = VHOST_USER_SLAVE_VRING_CALL,
			.hdr.flags = VHOST_USER_VERSION,
			.hdr.size = sizeof(msg.payload.vring_state),
			.payload.vring_state = {
				.idx = virtq_idx,
			},
		};

		usfstl_vhost_user_send_msg(dev, &msg);
		return;
	}

	written = write(dev->virtqs[virtq_idx].call_fd, &e, sizeof(e));
	USFSTL_ASSERT_EQ(written, (ssize_t)sizeof(e), "%zd");
}

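/*
 * Drain one virtqueue: pass each available buffer to the server's
 * handle() callback and complete it on the used ring afterwards. The
 * scatter-gather arrays are stack-preallocated for the common case.
 */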
static void usfstl_vhost_user_handle_queue(struct usfstl_vhost_user_dev_int *dev,
					   unsigned int virtq_idx)
{
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	while ((buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf))) {
		dev->ext.server->ops->handle(&dev->ext, buf, virtq_idx);

		usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
		usfstl_vhost_user_free_buf(buf);
	}
}

static void usfstl_vhost_user_job_callback(struct usfstl_job *job)
{
	struct usfstl_vhost_user_dev_int *dev = job->data;
	unsigned int virtq;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (!dev->virtqs[virtq].triggered)
			continue;
		dev->virtqs[virtq].triggered = false;

		usfstl_vhost_user_handle_queue(dev, virtq);
	}
}

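/*
 * Handle a kick on an input queue: mark it triggered and either process
 * it immediately (no scheduler) or schedule the IRQ job to run after the
 * configured interrupt latency.
 */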
static void usfstl_vhost_user_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
					 unsigned int virtq)
{
	if (!(dev->ext.server->input_queues & (1ULL << virtq)))
		return;

	dev->virtqs[virtq].triggered = true;

	if (usfstl_job_scheduled(&dev->irq_job))
		return;

	if (!dev->ext.server->scheduler) {
		usfstl_vhost_user_job_callback(&dev->irq_job);
		return;
	}

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_from(dev->ext.server->ctrl);

	dev->irq_job.start = usfstl_sched_current_time(dev->ext.server->scheduler) +
			     dev->ext.server->interrupt_latency;
	usfstl_sched_add_job(dev->ext.server->scheduler, &dev->irq_job);
}

static void usfstl_vhost_user_virtq_fdkick(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev = entry->data;
	unsigned int virtq;
	uint64_t v;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (entry == &dev->virtqs[virtq].entry)
			break;
	}

	USFSTL_ASSERT(virtq < dev->ext.server->max_queues);

	USFSTL_ASSERT_EQ((int)read(entry->fd, &v, sizeof(v)),
		       (int)sizeof(v), "%d");

	usfstl_vhost_user_virtq_kick(dev, virtq);
}

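/*
 * Unmap and close all currently mapped guest memory regions; used when a
 * new memory table arrives and on device teardown.
 */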
static void usfstl_vhost_user_clear_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;
	for (idx = 0; idx < MAX_REGIONS; idx++) {
		if (dev->region_vaddr[idx]) {
			munmap(dev->region_vaddr[idx],
			       dev->regions[idx].size + dev->regions[idx].mmap_offset);
			dev->region_vaddr[idx] = NULL;
		}

		if (dev->region_fds[idx] != -1) {
			close(dev->region_fds[idx]);
			dev->region_fds[idx] = -1;
		}
	}
}

static void usfstl_vhost_user_setup_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;

	for (idx = 0; idx < dev->n_regions; idx++) {
		USFSTL_ASSERT(!dev->region_vaddr[idx]);

		/*
		 * Cannot rely on the offset being page-aligned, I think ...
		 * adjust for it later when we translate addresses instead.
		 */
		dev->region_vaddr[idx] = mmap(NULL,
					      dev->regions[idx].size +
					      dev->regions[idx].mmap_offset,
					      PROT_READ | PROT_WRITE, MAP_SHARED,
					      dev->region_fds[idx], 0);
		USFSTL_ASSERT(dev->region_vaddr[idx] != (void *)-1,
			      "mmap() failed (%d) for fd %d", errno, dev->region_fds[idx]);
	}
}

static void
usfstl_vhost_user_update_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
				  unsigned int virtq, int fd)
{
	if (dev->virtqs[virtq].entry.fd != -1) {
		usfstl_loop_unregister(&dev->virtqs[virtq].entry);
		close(dev->virtqs[virtq].entry.fd);
		dev->virtqs[virtq].entry.fd = -1;
	}

	if (fd != -1) {
		dev->virtqs[virtq].entry.fd = fd;
		usfstl_loop_register(&dev->virtqs[virtq].entry);
	}
}

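/*
 * Tear down a per-connection device: unregister from the event loop,
 * cancel the IRQ job, close all queue/region/request fds, unmap guest
 * memory and finally notify the server via the disconnected() callback.
 */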
static void usfstl_vhost_user_dev_free(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int virtq;

	usfstl_loop_unregister(&dev->entry);
	usfstl_sched_del_job(&dev->irq_job);

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		usfstl_vhost_user_update_virtq_kick(dev, virtq, -1);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
	}

	usfstl_vhost_user_clear_mappings(dev);

	if (dev->req_fd != -1)
		close(dev->req_fd);

	if (dev->ext.server->ops->disconnected)
		dev->ext.server->ops->disconnected(&dev->ext);

	if (dev->entry.fd != -1)
		close(dev->entry.fd);

	free(dev);
}

static void usfstl_vhost_user_get_msg_fds(struct msghdr *msghdr,
					  int *outfds, int max_fds)
{
	struct cmsghdr *msg;
	int fds;

	for (msg = CMSG_FIRSTHDR(msghdr); msg; msg = CMSG_NXTHDR(msghdr, msg)) {
		if (msg->cmsg_level != SOL_SOCKET)
			continue;
		if (msg->cmsg_type != SCM_RIGHTS)
			continue;

		fds = (msg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		USFSTL_ASSERT(fds <= max_fds);
		memcpy(outfds, CMSG_DATA(msg), fds * sizeof(int));
		break;
	}
}

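/*
 * Main control-path handler: read one vhost-user message from the
 * connection socket, dispatch it by request type, and send a reply when
 * one is required (either because the request has a reply or because the
 * driver set the NEED_REPLY flag).
 */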
static void usfstl_vhost_user_handle_msg(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev;
	struct vhost_user_msg msg;
	uint8_t data[256]; // limits the config space size ...
	struct iovec msg_iov[3] = {
		[0] = {
			.iov_base = &msg.hdr,
			.iov_len = sizeof(msg.hdr),
		},
		[1] = {
			.iov_base = &msg.payload,
			.iov_len = sizeof(msg.payload),
		},
		[2] = {
			.iov_base = data,
			.iov_len = sizeof(data),
		},
	};
	uint8_t msg_control[CMSG_SPACE(sizeof(int) * MAX_REGIONS)] = { 0 };
	struct msghdr msghdr = {
		.msg_iov = msg_iov,
		.msg_iovlen = 3,
		.msg_control = msg_control,
		.msg_controllen = sizeof(msg_control),
	};
	ssize_t len;
	size_t reply_len = 0;
	unsigned int virtq;
	int fd;

	dev = container_of(entry, struct usfstl_vhost_user_dev_int, entry);

	if (usfstl_vhost_user_read_msg(entry->fd, &msghdr)) {
		usfstl_vhost_user_dev_free(dev);
		return;
	}
	len = msg.hdr.size;

	USFSTL_ASSERT((msg.hdr.flags & VHOST_USER_MSG_FLAGS_VERSION) ==
		    VHOST_USER_VERSION);

	switch (msg.hdr.request) {
	case VHOST_USER_GET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->features;
		msg.payload.u64 |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
		break;
	case VHOST_USER_SET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)sizeof(msg.payload.u64), "%zd");
		dev->ext.features = msg.payload.u64;
		break;
	case VHOST_USER_SET_OWNER:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		/* nothing to be done */
		break;
	case VHOST_USER_SET_MEM_TABLE:
		USFSTL_ASSERT(len <= (int)sizeof(msg.payload.mem_regions));
		USFSTL_ASSERT(msg.payload.mem_regions.n_regions <= MAX_REGIONS);
		usfstl_vhost_user_clear_mappings(dev);
		memcpy(dev->regions, msg.payload.mem_regions.regions,
		       msg.payload.mem_regions.n_regions *
		       sizeof(dev->regions[0]));
		dev->n_regions = msg.payload.mem_regions.n_regions;
		usfstl_vhost_user_get_msg_fds(&msghdr, dev->region_fds, MAX_REGIONS);
		usfstl_vhost_user_setup_mappings(dev);
		break;
	case VHOST_USER_SET_VRING_NUM:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].virtq.num =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_VRING_ADDR:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_addr));
		USFSTL_ASSERT(msg.payload.vring_addr.idx <=
			      dev->ext.server->max_queues);
		USFSTL_ASSERT_EQ(msg.payload.vring_addr.flags, (uint32_t)0, "0x%x");
		USFSTL_ASSERT(!dev->virtqs[msg.payload.vring_addr.idx].enabled);

		dev->virtqs[msg.payload.vring_addr.idx].last_avail_idx = 0;
		dev->virtqs[msg.payload.vring_addr.idx].virtq.desc =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.descriptor);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.used =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.used);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.avail =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.avail);
		USFSTL_ASSERT(dev->virtqs[msg.payload.vring_addr.idx].virtq.avail &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.desc &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.used);
		break;
	case VHOST_USER_SET_VRING_BASE:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_desc_index_split));
		USFSTL_ASSERT(msg.payload.vring_desc_index_split.idx < dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_desc_index_split.idx].last_avail_idx =
			msg.payload.vring_desc_index_split.index_in_avail_ring;
		break;
	case VHOST_USER_GET_VRING_BASE:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx < dev->ext.server->max_queues);
		USFSTL_ASSERT(msg.payload.vring_state.num == 0);  // reserved
		virtq = msg.payload.vring_state.idx;
		// Stop the queue.
		usfstl_vhost_user_update_virtq_kick(dev, virtq, -1);
		// Build the response.
		msg.payload.vring_desc_index_split.idx = virtq;
		msg.payload.vring_desc_index_split.index_in_avail_ring =
			dev->virtqs[virtq].last_avail_idx;
		reply_len = (int)sizeof(msg.payload.vring_desc_index_split);
		break;
	case VHOST_USER_SET_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr, &fd, 1);
		usfstl_vhost_user_update_virtq_kick(dev, virtq, fd);
		break;
	case VHOST_USER_SET_VRING_CALL:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			dev->virtqs[virtq].call_fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr,
						    &dev->virtqs[virtq].call_fd,
						    1);
		break;
	case VHOST_USER_GET_PROTOCOL_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->protocol_features;
		if (dev->ext.server->config && dev->ext.server->config_len)
			msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_DEVICE_STATE;
		break;
	case VHOST_USER_SET_VRING_ENABLE:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].enabled =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_PROTOCOL_FEATURES:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		dev->ext.protocol_features = msg.payload.u64;
		break;
	case VHOST_USER_SET_SLAVE_REQ_FD:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		if (dev->req_fd != -1)
			close(dev->req_fd);
		usfstl_vhost_user_get_msg_fds(&msghdr, &dev->req_fd, 1);
		USFSTL_ASSERT(dev->req_fd != -1);
		break;
	case VHOST_USER_GET_CONFIG:
		USFSTL_ASSERT(len == (int)(sizeof(msg.payload.cfg_space) +
					msg.payload.cfg_space.size));
		USFSTL_ASSERT(dev->ext.server->config && dev->ext.server->config_len);
		USFSTL_ASSERT(msg.payload.cfg_space.offset == 0);
		USFSTL_ASSERT(msg.payload.cfg_space.size <= dev->ext.server->config_len);
		msg.payload.cfg_space.flags = 0;
		msg_iov[1].iov_len = sizeof(msg.payload.cfg_space);
		msg_iov[2].iov_base = (void *)dev->ext.server->config;
		reply_len = len;
		break;
	case VHOST_USER_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		USFSTL_ASSERT(msg.payload.vring_state.num == 0);
		usfstl_vhost_user_virtq_kick(dev, msg.payload.vring_state.idx);
		break;
	case VHOST_USER_GET_SHARED_MEMORY_REGIONS:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = 0;
		break;
	case VHOST_USER_SET_DEVICE_STATE_FD: {
		USFSTL_ASSERT_EQ(len, sizeof(msg.payload.device_state_transfer), "%zd");
		USFSTL_ASSERT_EQ(msg.payload.device_state_transfer.migration_phase, /* stopped */ (uint32_t)0, "%d");
		// Read the attached FD.
		usfstl_vhost_user_get_msg_fds(&msghdr, &fd, 1);
		USFSTL_ASSERT_CMP(fd, !=, -1, "%d");
		// Delegate the data transfer to the backend.
		USFSTL_ASSERT(dev->ext.server->ops->start_data_transfer);
		dev->ext.server->ops->start_data_transfer(&dev->ext, msg.payload.device_state_transfer.transfer_direction, fd);
		// Respond with success and the "invalid FD" flag (because we
		// didn't include an FD in the response).
		msg.payload.u64 = 0x100;
		reply_len = sizeof(msg.payload.u64);
		break;
	}
	case VHOST_USER_CHECK_DEVICE_STATE: {
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		USFSTL_ASSERT(dev->ext.server->ops->check_data_transfer);
		dev->ext.server->ops->check_data_transfer(&dev->ext);
		msg.payload.u64 = 0;
		reply_len = sizeof(msg.payload.u64);
		break;
	}
	case VHOST_USER_SNAPSHOT: {
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		msg.payload.snapshot_response.bool_store = 1;
		reply_len = (int)sizeof(msg.payload.snapshot_response);
		break;
	}
	case VHOST_USER_RESTORE: {
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.restore_request));
		msg.payload.i8 = 1; // success
		reply_len = sizeof(msg.payload.i8);
		break;
	}
	default:
		USFSTL_ASSERT(0, "Unsupported message: %d\n", msg.hdr.request);
	}

	if (reply_len || (msg.hdr.flags & VHOST_USER_MSG_FLAGS_NEED_REPLY)) {
		size_t i, tmp;

		if (!reply_len) {
			msg.payload.u64 = 0;
			reply_len = sizeof(uint64_t);
		}

		msg.hdr.size = reply_len;
		msg.hdr.flags &= ~VHOST_USER_MSG_FLAGS_NEED_REPLY;
		msg.hdr.flags |= VHOST_USER_MSG_FLAGS_REPLY;

		msghdr.msg_control = NULL;
		msghdr.msg_controllen = 0;

		reply_len += sizeof(msg.hdr);

		tmp = reply_len;
		for (i = 0; tmp && i < msghdr.msg_iovlen; i++) {
			if (tmp <= msg_iov[i].iov_len)
				msg_iov[i].iov_len = tmp;
			tmp -= msg_iov[i].iov_len;
		}
		msghdr.msg_iovlen = i;

		while (reply_len) {
			len = sendmsg(entry->fd, &msghdr, 0);
			if (len < 0) {
				usfstl_vhost_user_dev_free(dev);
				return;
			}
			USFSTL_ASSERT(len != 0);
			reply_len -= len;

			for (i = 0; len && i < msghdr.msg_iovlen; i++) {
				unsigned int rm = len;

				if (msg_iov[i].iov_len <= (size_t)len)
					rm = msg_iov[i].iov_len;
				len -= rm;
				msg_iov[i].iov_len -= rm;
				msg_iov[i].iov_base += rm;
			}
		}
	}
}

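/*
 * Accept callback for the unix socket: allocate the per-connection device
 * state (one entry per virtqueue), initialize all fds to -1, set up the
 * IRQ job and register the connection fd with the event loop.
 */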
static void usfstl_vhost_user_connected(int fd, void *data)
{
	struct usfstl_vhost_user_server *server = data;
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int i;

	dev = calloc(1, sizeof(*dev) +
			sizeof(dev->virtqs[0]) * server->max_queues);

	USFSTL_ASSERT(dev);

	for (i = 0; i < server->max_queues; i++) {
		dev->virtqs[i].call_fd = -1;
		dev->virtqs[i].entry.fd = -1;
		dev->virtqs[i].entry.data = dev;
		dev->virtqs[i].entry.handler = usfstl_vhost_user_virtq_fdkick;
	}

	for (i = 0; i < MAX_REGIONS; i++)
		dev->region_fds[i] = -1;
	dev->req_fd = -1;

	dev->ext.server = server;
	dev->irq_job.data = dev;
	dev->irq_job.name = "vhost-user-irq";
	dev->irq_job.priority = 0x10000000;
	dev->irq_job.callback = usfstl_vhost_user_job_callback;
	usfstl_list_init(&dev->fds);

	if (server->ops->connected)
		server->ops->connected(&dev->ext);

	dev->entry.fd = fd;
	dev->entry.handler = usfstl_vhost_user_handle_msg;

	usfstl_loop_register(&dev->entry);
}

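/*
 * Minimal usage sketch (illustrative only): a caller typically defines a
 * server description and starts it once the event loop is set up. The
 * exact struct layouts live in usfstl/vhost.h; only fields referenced in
 * this file are shown, the ops struct tag and handler signature are
 * assumed, and the socket path / names below are hypothetical.
 *
 *	static void my_handle(struct usfstl_vhost_user_dev *dev,
 *			      struct usfstl_vhost_user_buf *buf,
 *			      unsigned int vring)
 *	{
 *		// consume buf->out_sg / fill buf->in_sg, set buf->written
 *	}
 *
 *	static const struct usfstl_vhost_user_ops my_ops = {
 *		.handle = my_handle,
 *	};
 *
 *	static struct usfstl_vhost_user_server my_server = {
 *		.ops = &my_ops,
 *		.socket = "/tmp/my-device.sock",
 *		.max_queues = 2,
 *		.input_queues = 1 << 0,	// queue 0 is driver->device
 *		.features = 1ULL << VIRTIO_F_VERSION_1,
 *	};
 *
 *	usfstl_vhost_user_server_start(&my_server);
 */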
void usfstl_vhost_user_server_start(struct usfstl_vhost_user_server *server)
{
	USFSTL_ASSERT(server->ops);
	USFSTL_ASSERT(server->socket);

	usfstl_uds_create(server->socket, usfstl_vhost_user_connected, server);
}

void usfstl_vhost_user_server_stop(struct usfstl_vhost_user_server *server)
{
	usfstl_uds_remove(server->socket);
}

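/*
 * Device-to-driver data path: copy @data into the next available
 * device-writable buffer of the given virtqueue and complete it. Silently
 * drops the data if the queue is not enabled or has no buffer available.
 */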
void usfstl_vhost_user_dev_notify(struct usfstl_vhost_user_dev *extdev,
				  unsigned int virtq_idx,
				  const uint8_t *data, size_t datalen)
{
	struct usfstl_vhost_user_dev_int *dev;
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	USFSTL_ASSERT(virtq_idx <= dev->ext.server->max_queues);

	if (!dev->virtqs[virtq_idx].enabled)
		return;

	buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf);
	if (!buf)
		return;

	USFSTL_ASSERT(buf->n_in_sg && !buf->n_out_sg);
	iov_fill(buf->in_sg, buf->n_in_sg, data, datalen);
	buf->written = datalen;

	usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
	usfstl_vhost_user_free_buf(buf);
}

void usfstl_vhost_user_config_changed(struct usfstl_vhost_user_dev *dev)
{
	struct usfstl_vhost_user_dev_int *idev;
	struct vhost_user_msg msg = {
		.hdr.request = VHOST_USER_SLAVE_CONFIG_CHANGE_MSG,
		.hdr.flags = VHOST_USER_VERSION,
	};

	idev = container_of(dev, struct usfstl_vhost_user_dev_int, ext);

	if (!(idev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_CONFIG)))
		return;

	usfstl_vhost_user_send_msg(idev, &msg);
}

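/*
 * Address translation helpers: usfstl_vhost_user_to_va() maps driver
 * ("user") addresses as passed in SET_VRING_ADDR, while
 * usfstl_vhost_phys_to_va() maps guest-physical addresses found in
 * descriptors. Both walk the mapped memory regions and assert on a miss.
 */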
void *usfstl_vhost_user_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].user_addr &&
		    addr < dev->regions[region].user_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].user_addr +
				dev->regions[region].mmap_offset);
	}
	USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
	return NULL;
}

void *usfstl_vhost_phys_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].guest_phys_addr &&
		    addr < dev->regions[region].guest_phys_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].guest_phys_addr +
				dev->regions[region].mmap_offset);
	}

	USFSTL_ASSERT(0, "cannot translate physical address %"PRIx64"\n", addr);
	return NULL;
}

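/*
 * Small scatter-gather helpers: iov_len() sums the lengths of an iovec
 * array, iov_fill() copies a flat buffer into it, and iov_read() copies
 * out of it; the latter two return the number of bytes copied.
 */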
size_t iov_len(struct iovec *sg, unsigned int nsg)
{
	size_t len = 0;
	unsigned int i;

	for (i = 0; i < nsg; i++)
		len += sg[i].iov_len;

	return len;
}

size_t iov_fill(struct iovec *sg, unsigned int nsg,
		const void *_buf, size_t buflen)
{
	const char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

#define min(a, b) ({ typeof(a) _a = (a); typeof(b) _b = (b); _a < _b ? _a : _b; })
	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(sg[i].iov_base, buf, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}

size_t iov_read(void *_buf, size_t buflen,
		struct iovec *sg, unsigned int nsg)
{
	char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

#define min(a, b) ({ typeof(a) _a = (a); typeof(b) _b = (b); _a < _b ? _a : _b; })
	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(buf, sg[i].iov_base, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}