xref: /aosp_15_r20/external/virglrenderer/src/proxy/proxy_socket.c (revision bbecb9d118dfdb95f99bd754f8fa9be01f189df3)
1 /*
2  * Copyright 2021 Google LLC
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "proxy_socket.h"
7 
8 #include <poll.h>
9 #include <sys/socket.h>
10 #include <sys/types.h>
11 #include <sys/uio.h>
12 #include <unistd.h>
13 
14 #define PROXY_SOCKET_MAX_FD_COUNT 8
15 
/* this is only used when the render server is started on demand */
/*
 * Creates a connected pair of AF_UNIX SOCK_SEQPACKET sockets and stores their
 * descriptors in out_fds[0] and out_fds[1].
 *
 * Returns true on success; logs and returns false if socketpair() fails.
 */
bool
proxy_socket_pair(int out_fds[static 2])
{
   if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, out_fds) != 0) {
      proxy_log("failed to create socket pair");
      return false;
   }

   return true;
}
28 
/*
 * Reports whether fd is a SOCK_SEQPACKET socket, as determined by
 * getsockopt(SO_TYPE). Logs and returns false if the query fails (e.g. fd is
 * not a socket).
 */
bool
proxy_socket_is_seqpacket(int fd)
{
   int sock_type;
   socklen_t sock_type_len = sizeof(sock_type);

   const int err = getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len);
   if (err != 0) {
      proxy_log("fd %d err %s", fd, strerror(errno));
      return false;
   }

   return sock_type == SOCK_SEQPACKET;
}
40 
41 void
proxy_socket_init(struct proxy_socket * socket,int fd)42 proxy_socket_init(struct proxy_socket *socket, int fd)
43 {
44    /* TODO make fd non-blocking and perform io with timeout */
45    assert(fd >= 0);
46    *socket = (struct proxy_socket){
47       .fd = fd,
48    };
49 }
50 
51 void
proxy_socket_fini(struct proxy_socket * socket)52 proxy_socket_fini(struct proxy_socket *socket)
53 {
54    close(socket->fd);
55 }
56 
57 bool
proxy_socket_is_connected(const struct proxy_socket * socket)58 proxy_socket_is_connected(const struct proxy_socket *socket)
59 {
60    struct pollfd poll_fd = {
61       .fd = socket->fd,
62    };
63 
64    while (true) {
65       const int ret = poll(&poll_fd, 1, 0);
66       if (ret == 0) {
67          return true;
68       } else if (ret < 0) {
69          if (errno == EINTR || errno == EAGAIN)
70             continue;
71 
72          proxy_log("failed to poll socket");
73          return false;
74       }
75 
76       if (poll_fd.revents & (POLLERR | POLLHUP | POLLNVAL)) {
77          proxy_log("socket disconnected");
78          return false;
79       }
80 
81       return true;
82    }
83 }
84 
/*
 * Looks at the first control message of msg. If it is an SOL_SOCKET /
 * SCM_RIGHTS message, returns a pointer to its descriptor payload and stores
 * the number of descriptors in *out_count. Otherwise (no control message,
 * another type, or a malformed length) it returns NULL and sets *out_count to 0.
 */
static const int *
get_received_fds(const struct msghdr *msg, int *out_count)
{
   const struct cmsghdr *hdr = CMSG_FIRSTHDR(msg);
   const bool has_rights = hdr && hdr->cmsg_level == SOL_SOCKET &&
                           hdr->cmsg_type == SCM_RIGHTS &&
                           hdr->cmsg_len >= CMSG_LEN(0);
   if (unlikely(!has_rights)) {
      *out_count = 0;
      return NULL;
   }

   /* payload bytes following the header, interpreted as an array of ints */
   const size_t payload_len = hdr->cmsg_len - CMSG_LEN(0);
   *out_count = payload_len / sizeof(int);
   return (const int *)CMSG_DATA(hdr);
}
98 
99 static bool
proxy_socket_recvmsg(struct proxy_socket * socket,struct msghdr * msg)100 proxy_socket_recvmsg(struct proxy_socket *socket, struct msghdr *msg)
101 {
102    do {
103       const ssize_t s = recvmsg(socket->fd, msg, MSG_CMSG_CLOEXEC);
104       if (unlikely(s < 0)) {
105          if (errno == EAGAIN || errno == EINTR)
106             continue;
107 
108          proxy_log("failed to receive message: %s", strerror(errno));
109          return false;
110       }
111 
112       assert(msg->msg_iovlen == 1);
113       if (unlikely((msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC)) ||
114                    msg->msg_iov[0].iov_len != (size_t)s)) {
115          proxy_log("failed to receive message: truncated or incomplete");
116 
117          int fd_count;
118          const int *fds = get_received_fds(msg, &fd_count);
119          for (int i = 0; i < fd_count; i++)
120             close(fds[i]);
121 
122          return false;
123       }
124 
125       return true;
126    } while (true);
127 }
128 
129 static bool
proxy_socket_receive_reply_internal(struct proxy_socket * socket,void * data,size_t size,int * fds,int max_fd_count,int * out_fd_count)130 proxy_socket_receive_reply_internal(struct proxy_socket *socket,
131                                     void *data,
132                                     size_t size,
133                                     int *fds,
134                                     int max_fd_count,
135                                     int *out_fd_count)
136 {
137    assert(data && size);
138    struct msghdr msg = {
139       .msg_iov =
140          &(struct iovec){
141             .iov_base = data,
142             .iov_len = size,
143          },
144       .msg_iovlen = 1,
145    };
146 
147    char cmsg_buf[CMSG_SPACE(sizeof(*fds) * PROXY_SOCKET_MAX_FD_COUNT)];
148    if (max_fd_count) {
149       assert(fds && max_fd_count <= PROXY_SOCKET_MAX_FD_COUNT);
150       msg.msg_control = cmsg_buf;
151       msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * max_fd_count);
152 
153       struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
154       memset(cmsg, 0, sizeof(*cmsg));
155    }
156 
157    if (!proxy_socket_recvmsg(socket, &msg))
158       return false;
159 
160    if (max_fd_count) {
161       int received_fd_count;
162       const int *received_fds = get_received_fds(&msg, &received_fd_count);
163       assert(received_fd_count <= max_fd_count);
164 
165       memcpy(fds, received_fds, sizeof(*fds) * received_fd_count);
166       *out_fd_count = received_fd_count;
167    } else if (out_fd_count) {
168       *out_fd_count = 0;
169    }
170 
171    return true;
172 }
173 
/*
 * Receives a reply of exactly size bytes into data, without accepting any
 * file descriptors. Returns false on failure.
 */
bool
proxy_socket_receive_reply(struct proxy_socket *socket, void *data, size_t size)
{
   const bool ok = proxy_socket_receive_reply_internal(socket, data, size,
                                                       NULL, /* fds */
                                                       0,    /* max_fd_count */
                                                       NULL /* out_fd_count */);
   return ok;
}
179 
/*
 * Receives a reply of exactly size bytes into data, plus up to max_fd_count
 * descriptors into fds. Their count is stored in *out_fd_count. Returns false
 * on failure.
 */
bool
proxy_socket_receive_reply_with_fds(struct proxy_socket *socket,
                                    void *data,
                                    size_t size,
                                    int *fds,
                                    int max_fd_count,
                                    int *out_fd_count)
{
   const bool ok = proxy_socket_receive_reply_internal(socket, data, size, fds,
                                                       max_fd_count, out_fd_count);
   return ok;
}
191 
192 static bool
proxy_socket_sendmsg(struct proxy_socket * socket,const struct msghdr * msg)193 proxy_socket_sendmsg(struct proxy_socket *socket, const struct msghdr *msg)
194 {
195    do {
196       const ssize_t s = sendmsg(socket->fd, msg, MSG_NOSIGNAL);
197       if (unlikely(s < 0)) {
198          if (errno == EAGAIN || errno == EINTR)
199             continue;
200 
201          proxy_log("failed to send message: %s", strerror(errno));
202          return false;
203       }
204 
205       /* no partial send since the socket type is SOCK_SEQPACKET */
206       assert(msg->msg_iovlen == 1 && msg->msg_iov[0].iov_len == (size_t)s);
207       return true;
208    } while (true);
209 }
210 
211 static bool
proxy_socket_send_request_internal(struct proxy_socket * socket,const void * data,size_t size,const int * fds,int fd_count)212 proxy_socket_send_request_internal(struct proxy_socket *socket,
213                                    const void *data,
214                                    size_t size,
215                                    const int *fds,
216                                    int fd_count)
217 {
218    assert(data && size);
219    struct msghdr msg = {
220       .msg_iov =
221          &(struct iovec){
222             .iov_base = (void *)data,
223             .iov_len = size,
224          },
225       .msg_iovlen = 1,
226    };
227 
228    char cmsg_buf[CMSG_SPACE(sizeof(*fds) * PROXY_SOCKET_MAX_FD_COUNT)];
229    if (fd_count) {
230       assert(fds && fd_count <= PROXY_SOCKET_MAX_FD_COUNT);
231       msg.msg_control = cmsg_buf;
232       msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * fd_count);
233 
234       struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
235       cmsg->cmsg_level = SOL_SOCKET;
236       cmsg->cmsg_type = SCM_RIGHTS;
237       cmsg->cmsg_len = CMSG_LEN(sizeof(*fds) * fd_count);
238       memcpy(CMSG_DATA(cmsg), fds, sizeof(*fds) * fd_count);
239    }
240 
241    return proxy_socket_sendmsg(socket, &msg);
242 }
243 
/* Sends data[0..size) as one message without attaching descriptors. */
bool
proxy_socket_send_request(struct proxy_socket *socket, const void *data, size_t size)
{
   const bool ok = proxy_socket_send_request_internal(socket, data, size,
                                                      NULL, /* fds */
                                                      0 /* fd_count */);
   return ok;
}
249 
/*
 * Sends data[0..size) as one message with the first fd_count descriptors of
 * fds attached. Returns false on failure.
 */
bool
proxy_socket_send_request_with_fds(struct proxy_socket *socket,
                                   const void *data,
                                   size_t size,
                                   const int *fds,
                                   int fd_count)
{
   const bool ok =
      proxy_socket_send_request_internal(socket, data, size, fds, fd_count);
   return ok;
}
259