1 /*
2 * Copyright 2021 Google LLC
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "proxy_socket.h"
7
8 #include <poll.h>
9 #include <sys/socket.h>
10 #include <sys/types.h>
11 #include <sys/uio.h>
12 #include <unistd.h>
13
/* Upper bound on the number of file descriptors carried by one message; it
 * sizes the on-stack ancillary-data buffers used by the send and receive
 * helpers in this file.
 */
#define PROXY_SOCKET_MAX_FD_COUNT 8
15
/**
 * Create a connected pair of AF_UNIX SOCK_SEQPACKET sockets.
 *
 * This is only used when the render server is started on demand.
 *
 * \param out_fds  receives the two socket descriptors on success
 * \return true on success; false (after logging) when socketpair() fails
 */
bool
proxy_socket_pair(int out_fds[static 2])
{
   const bool created = socketpair(AF_UNIX, SOCK_SEQPACKET, 0, out_fds) == 0;
   if (!created)
      proxy_log("failed to create socket pair");

   return created;
}
28
/**
 * Check whether fd refers to a SOCK_SEQPACKET socket.
 *
 * \param fd  descriptor to query with getsockopt(SO_TYPE)
 * \return true only when the query succeeds and reports SOCK_SEQPACKET; a
 *         failed query is logged together with the errno string and yields
 *         false
 */
bool
proxy_socket_is_seqpacket(int fd)
{
   int sock_type;
   socklen_t opt_len = sizeof(sock_type);

   const int ret = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &opt_len);
   if (ret != 0) {
      proxy_log("fd %d err %s", fd, strerror(errno));
      return false;
   }

   return sock_type == SOCK_SEQPACKET;
}
40
41 void
proxy_socket_init(struct proxy_socket * socket,int fd)42 proxy_socket_init(struct proxy_socket *socket, int fd)
43 {
44 /* TODO make fd non-blocking and perform io with timeout */
45 assert(fd >= 0);
46 *socket = (struct proxy_socket){
47 .fd = fd,
48 };
49 }
50
51 void
proxy_socket_fini(struct proxy_socket * socket)52 proxy_socket_fini(struct proxy_socket *socket)
53 {
54 close(socket->fd);
55 }
56
57 bool
proxy_socket_is_connected(const struct proxy_socket * socket)58 proxy_socket_is_connected(const struct proxy_socket *socket)
59 {
60 struct pollfd poll_fd = {
61 .fd = socket->fd,
62 };
63
64 while (true) {
65 const int ret = poll(&poll_fd, 1, 0);
66 if (ret == 0) {
67 return true;
68 } else if (ret < 0) {
69 if (errno == EINTR || errno == EAGAIN)
70 continue;
71
72 proxy_log("failed to poll socket");
73 return false;
74 }
75
76 if (poll_fd.revents & (POLLERR | POLLHUP | POLLNVAL)) {
77 proxy_log("socket disconnected");
78 return false;
79 }
80
81 return true;
82 }
83 }
84
/**
 * Locate the file descriptors carried by the first control message of msg.
 *
 * Only the first control message header is examined.  It must be a
 * SOL_SOCKET/SCM_RIGHTS message whose length covers at least the header.
 *
 * \param msg        message header previously filled in by recvmsg()
 * \param out_count  receives the number of descriptors (0 when none)
 * \return pointer to the descriptor array inside the control buffer, or NULL
 *         when the first control message is missing or not SCM_RIGHTS
 */
static const int *
get_received_fds(const struct msghdr *msg, int *out_count)
{
   const struct cmsghdr *hdr = CMSG_FIRSTHDR(msg);
   const bool carries_fds = hdr && hdr->cmsg_level == SOL_SOCKET &&
                            hdr->cmsg_type == SCM_RIGHTS &&
                            hdr->cmsg_len >= CMSG_LEN(0);
   if (unlikely(!carries_fds)) {
      *out_count = 0;
      return NULL;
   }

   /* payload size is the control message length minus its header */
   const size_t payload_size = hdr->cmsg_len - CMSG_LEN(0);
   *out_count = payload_size / sizeof(int);
   return (const int *)CMSG_DATA(hdr);
}
98
99 static bool
proxy_socket_recvmsg(struct proxy_socket * socket,struct msghdr * msg)100 proxy_socket_recvmsg(struct proxy_socket *socket, struct msghdr *msg)
101 {
102 do {
103 const ssize_t s = recvmsg(socket->fd, msg, MSG_CMSG_CLOEXEC);
104 if (unlikely(s < 0)) {
105 if (errno == EAGAIN || errno == EINTR)
106 continue;
107
108 proxy_log("failed to receive message: %s", strerror(errno));
109 return false;
110 }
111
112 assert(msg->msg_iovlen == 1);
113 if (unlikely((msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) ||
114 msg->msg_iov[0].iov_len != (size_t)s)) {
115 proxy_log("failed to receive message: truncated or incomplete");
116
117 int fd_count;
118 const int *fds = get_received_fds(msg, &fd_count);
119 for (int i = 0; i < fd_count; i++)
120 close(fds[i]);
121
122 return false;
123 }
124
125 return true;
126 } while (true);
127 }
128
129 static bool
proxy_socket_receive_reply_internal(struct proxy_socket * socket,void * data,size_t size,int * fds,int max_fd_count,int * out_fd_count)130 proxy_socket_receive_reply_internal(struct proxy_socket *socket,
131 void *data,
132 size_t size,
133 int *fds,
134 int max_fd_count,
135 int *out_fd_count)
136 {
137 assert(data && size);
138 struct msghdr msg = {
139 .msg_iov =
140 &(struct iovec){
141 .iov_base = data,
142 .iov_len = size,
143 },
144 .msg_iovlen = 1,
145 };
146
147 char cmsg_buf[CMSG_SPACE(sizeof(*fds) * PROXY_SOCKET_MAX_FD_COUNT)];
148 if (max_fd_count) {
149 assert(fds && max_fd_count <= PROXY_SOCKET_MAX_FD_COUNT);
150 msg.msg_control = cmsg_buf;
151 msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * max_fd_count);
152
153 struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
154 memset(cmsg, 0, sizeof(*cmsg));
155 }
156
157 if (!proxy_socket_recvmsg(socket, &msg))
158 return false;
159
160 if (max_fd_count) {
161 int received_fd_count;
162 const int *received_fds = get_received_fds(&msg, &received_fd_count);
163 assert(received_fd_count <= max_fd_count);
164
165 memcpy(fds, received_fds, sizeof(*fds) * received_fd_count);
166 *out_fd_count = received_fd_count;
167 } else if (out_fd_count) {
168 *out_fd_count = 0;
169 }
170
171 return true;
172 }
173
/**
 * Receive a reply of exactly size bytes that carries no file descriptors.
 *
 * \return the result of proxy_socket_receive_reply_internal() called with no
 *         descriptor array and a descriptor capacity of 0
 */
bool
proxy_socket_receive_reply(struct proxy_socket *socket, void *data, size_t size)
{
   int *const no_fds = NULL;
   const int no_fd_capacity = 0;

   return proxy_socket_receive_reply_internal(socket, data, size, no_fds,
                                              no_fd_capacity, NULL);
}
179
/**
 * Receive a reply of exactly size bytes together with up to max_fd_count
 * file descriptors.
 *
 * All arguments are forwarded unchanged to
 * proxy_socket_receive_reply_internal(), whose result is returned.
 */
bool
proxy_socket_receive_reply_with_fds(struct proxy_socket *socket,
                                    void *data,
                                    size_t size,
                                    int *fds,
                                    int max_fd_count,
                                    int *out_fd_count)
{
   const bool received = proxy_socket_receive_reply_internal(
      socket, data, size, fds, max_fd_count, out_fd_count);

   return received;
}
191
192 static bool
proxy_socket_sendmsg(struct proxy_socket * socket,const struct msghdr * msg)193 proxy_socket_sendmsg(struct proxy_socket *socket, const struct msghdr *msg)
194 {
195 do {
196 const ssize_t s = sendmsg(socket->fd, msg, MSG_NOSIGNAL);
197 if (unlikely(s < 0)) {
198 if (errno == EAGAIN || errno == EINTR)
199 continue;
200
201 proxy_log("failed to send message: %s", strerror(errno));
202 return false;
203 }
204
205 /* no partial send since the socket type is SOCK_SEQPACKET */
206 assert(msg->msg_iovlen == 1 && msg->msg_iov[0].iov_len == (size_t)s);
207 return true;
208 } while (true);
209 }
210
211 static bool
proxy_socket_send_request_internal(struct proxy_socket * socket,const void * data,size_t size,const int * fds,int fd_count)212 proxy_socket_send_request_internal(struct proxy_socket *socket,
213 const void *data,
214 size_t size,
215 const int *fds,
216 int fd_count)
217 {
218 assert(data && size);
219 struct msghdr msg = {
220 .msg_iov =
221 &(struct iovec){
222 .iov_base = (void *)data,
223 .iov_len = size,
224 },
225 .msg_iovlen = 1,
226 };
227
228 char cmsg_buf[CMSG_SPACE(sizeof(*fds) * PROXY_SOCKET_MAX_FD_COUNT)];
229 if (fd_count) {
230 assert(fds && fd_count <= PROXY_SOCKET_MAX_FD_COUNT);
231 msg.msg_control = cmsg_buf;
232 msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * fd_count);
233
234 struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
235 cmsg->cmsg_level = SOL_SOCKET;
236 cmsg->cmsg_type = SCM_RIGHTS;
237 cmsg->cmsg_len = CMSG_LEN(sizeof(*fds) * fd_count);
238 memcpy(CMSG_DATA(cmsg), fds, sizeof(*fds) * fd_count);
239 }
240
241 return proxy_socket_sendmsg(socket, &msg);
242 }
243
/**
 * Send a request of size bytes without attaching any file descriptors.
 *
 * \return the result of proxy_socket_send_request_internal() called with no
 *         descriptor array and a descriptor count of 0
 */
bool
proxy_socket_send_request(struct proxy_socket *socket, const void *data, size_t size)
{
   const int *const no_fds = NULL;
   const int no_fd_count = 0;

   return proxy_socket_send_request_internal(socket, data, size, no_fds,
                                             no_fd_count);
}
249
/**
 * Send a request of size bytes together with fd_count file descriptors.
 *
 * All arguments are forwarded unchanged to
 * proxy_socket_send_request_internal(), whose result is returned.
 */
bool
proxy_socket_send_request_with_fds(struct proxy_socket *socket,
                                   const void *data,
                                   size_t size,
                                   const int *fds,
                                   int fd_count)
{
   const bool sent =
      proxy_socket_send_request_internal(socket, data, size, fds, fd_count);

   return sent;
}
259