xref: /aosp_15_r20/external/pigweed/pw_stream/socket_stream.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_stream/socket_stream.h"
16 
17 #if defined(_WIN32) && _WIN32
18 #include <fcntl.h>
19 #include <io.h>
20 #include <winsock2.h>
21 #include <ws2tcpip.h>
22 #define SHUT_RDWR SD_BOTH
23 #else
24 #include <arpa/inet.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <poll.h>
28 #include <sys/socket.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 #endif  // defined(_WIN32) && _WIN32
32 
33 #include <cerrno>
34 #include <cstring>
35 
36 #include "pw_assert/check.h"
37 #include "pw_log/log.h"
38 #include "pw_status/status.h"
39 #include "pw_string/to_string.h"
40 
41 namespace pw::stream {
42 namespace {
43 
// Length of the pending-connection queue passed to listen() in
// ServerSocket::Listen().
constexpr uint32_t kServerBacklogLength{1};

// Host that SocketStream::Connect() falls back to when given a null host.
constexpr const char* kLocalhostAddress{"localhost"};
46 
// Applies the per-socket options this file relies on to a newly created or
// accepted socket descriptor. A failure is logged as a warning and otherwise
// ignored, so the socket is still usable.
void ConfigureSocket([[maybe_unused]] int socket) {
#if defined(__APPLE__)
  // Ask the kernel not to raise SIGPIPE when the remote peer drops the
  // connection. This socket option is only applied on Apple platforms; Linux
  // builds pass MSG_NOSIGNAL to send() in DoWrite() instead.
  constexpr int kEnable = 1;
  const int result =
      setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &kEnable, sizeof(int));
  if (result < 0) {
    PW_LOG_WARN("Failed to set SO_NOSIGPIPE: %s", std::strerror(errno));
  }
#endif  // defined(__APPLE__)
}
58 
#if defined(_WIN32) && _WIN32
// Windows shims: expose the POSIX-style names used throughout this file and
// forward them to their Winsock / CRT equivalents.

// Winsock sockets are released with closesocket(), not close().
int close(SOCKET s) { return closesocket(s); }

// Forwards to the CRT's _write().
ssize_t write(int fd, const void* buf, size_t count) {
  return _write(fd, buf, count);
}

// Forwards to WSAPoll().
int poll(struct pollfd* fds, unsigned int nfds, int timeout) {
  return WSAPoll(fds, nfds, timeout);
}

// Creates an anonymous CRT pipe with a 256-byte buffer in binary mode.
int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); }

// Adapts the POSIX setsockopt() signature to Winsock's, which takes a SOCKET,
// a const char* option buffer and an int length.
int setsockopt(
    int fd, int level, int optname, const void* optval, unsigned int optlen) {
  const SOCKET winsock_fd = static_cast<SOCKET>(fd);
  const char* const option_bytes = static_cast<const char*>(optval);
  const int option_length = static_cast<int>(optlen);
  return setsockopt(winsock_fd, level, optname, option_bytes, option_length);
}

// Requests Winsock 2.2 on construction (crashing via PW_CHECK if that fails)
// and calls WSACleanup() on destruction.
class WinsockInitializer {
 public:
  WinsockInitializer() {
    WSADATA wsa_data = {};
    const int startup_result = WSAStartup(MAKEWORD(2, 2), &wsa_data);
    PW_CHECK_INT_EQ(startup_result, 0, "Failed to initialize winsock");
  }

  ~WinsockInitializer() {
    // TODO: b/301545011 - This currently fails, probably a cleanup race.
    WSACleanup();
  }
};

// File-scope instance: Winsock is started during static initialization and
// cleaned up during static destruction.
[[maybe_unused]] WinsockInitializer initializer;

#endif  // defined(_WIN32) && _WIN32
97 
98 }  // namespace
99 
Connect(const char * host,uint16_t port)100 Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
101   if (host == nullptr) {
102     host = kLocalhostAddress;
103   }
104 
105   struct addrinfo hints = {};
106   struct addrinfo* res;
107   char port_buffer[6];
108   PW_CHECK(ToString(port, port_buffer).ok());
109   hints.ai_family = AF_UNSPEC;
110   hints.ai_socktype = SOCK_STREAM;
111   hints.ai_flags = AI_NUMERICSERV;
112   if (getaddrinfo(host, port_buffer, &hints, &res) != 0) {
113     PW_LOG_ERROR("Failed to configure connection address for socket");
114     return Status::InvalidArgument();
115   }
116 
117   struct addrinfo* rp;
118   int connection_fd;
119   for (rp = res; rp != nullptr; rp = rp->ai_next) {
120     connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
121     if (connection_fd != kInvalidFd) {
122       break;
123     }
124   }
125 
126   if (connection_fd == kInvalidFd) {
127     PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno));
128     freeaddrinfo(res);
129     return Status::Unknown();
130   }
131 
132   ConfigureSocket(connection_fd);
133   if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) {
134     close(connection_fd);
135     PW_LOG_ERROR(
136         "Failed to connect to %s:%d: %s", host, port, std::strerror(errno));
137     freeaddrinfo(res);
138     return Status::Unknown();
139   }
140 
141   // Mark as ready and take ownership of the connection by this object.
142   {
143     std::lock_guard lock(connection_mutex_);
144     connection_fd_ = connection_fd;
145     TakeConnectionWithLockHeld();
146     ready_ = true;
147   }
148 
149   freeaddrinfo(res);
150   return OkStatus();
151 }
152 
153 // Configures socket options.
SetSockOpt(int level,int optname,const void * optval,unsigned int optlen)154 int SocketStream::SetSockOpt(int level,
155                              int optname,
156                              const void* optval,
157                              unsigned int optlen) {
158   ConnectionOwnership ownership(this);
159   if (ownership.fd() == kInvalidFd) {
160     return EBADF;
161   }
162   return setsockopt(ownership.fd(), level, optname, optval, optlen);
163 }
164 
Close()165 void SocketStream::Close() {
166   ConnectionOwnership ownership(this);
167   {
168     std::lock_guard lock(connection_mutex_);
169     if (ready_) {
170       // Shutdown the connection and send tear down notification to unblock any
171       // waiters.
172       if (connection_fd_ != kInvalidFd) {
173         shutdown(connection_fd_, SHUT_RDWR);
174       }
175       if (connection_pipe_w_fd_ != kInvalidFd) {
176         write(connection_pipe_w_fd_, "T", 1);
177       }
178 
179       // Release ownership of the connection by this object and mark as no
180       // longer ready.
181       ReleaseConnectionWithLockHeld();
182       ready_ = false;
183     }
184   }
185 }
186 
DoWrite(span<const std::byte> data)187 Status SocketStream::DoWrite(span<const std::byte> data) {
188   int send_flags = 0;
189 #if defined(__linux__)
190   // Use MSG_NOSIGNAL to avoid getting a SIGPIPE signal when the remote
191   // peer drops the connection. This is supported on Linux only.
192   send_flags |= MSG_NOSIGNAL;
193 #endif  // defined(__linux__)
194 
195   ssize_t bytes_sent;
196   {
197     ConnectionOwnership ownership(this);
198     if (ownership.fd() == kInvalidFd) {
199       return Status::Unknown();
200     }
201     bytes_sent = send(ownership.fd(),
202                       reinterpret_cast<const char*>(data.data()),
203                       data.size_bytes(),
204                       send_flags);
205   }
206 
207   if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) {
208     if (errno == EPIPE) {
209       // An EPIPE indicates that the connection is closed.  Return an OutOfRange
210       // error.
211       return Status::OutOfRange();
212     }
213 
214     return Status::Unknown();
215   }
216   return OkStatus();
217 }
218 
DoRead(ByteSpan dest)219 StatusWithSize SocketStream::DoRead(ByteSpan dest) {
220   ConnectionOwnership ownership(this);
221   if (ownership.fd() == kInvalidFd) {
222     return StatusWithSize::Unknown();
223   }
224 
225   // Wait for data to read or a tear down notification.
226   pollfd fds_to_poll[2];
227   fds_to_poll[0].fd = ownership.fd();
228   fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
229   fds_to_poll[1].fd = ownership.pipe_r_fd();
230   fds_to_poll[1].events = POLLIN;
231   poll(fds_to_poll, 2, -1);
232   if (!(fds_to_poll[0].revents & POLLIN)) {
233     return StatusWithSize::Unknown();
234   }
235 
236   ssize_t bytes_rcvd = recv(ownership.fd(),
237                             reinterpret_cast<char*>(dest.data()),
238                             dest.size_bytes(),
239                             0);
240   if (bytes_rcvd == 0) {
241     // Remote peer has closed the connection.
242     Close();
243     return StatusWithSize::OutOfRange();
244   } else if (bytes_rcvd < 0) {
245     if (errno == EAGAIN || errno == EWOULDBLOCK) {
246       // Socket timed out when trying to read.
247       // This should only occur if SO_RCVTIMEO was configured to be nonzero, or
248       // if the socket was opened with the O_NONBLOCK flag to prevent any
249       // blocking when performing reads or writes.
250       return StatusWithSize::ResourceExhausted();
251     }
252     return StatusWithSize::Unknown();
253   }
254   return StatusWithSize(bytes_rcvd);
255 }
256 
TakeConnection()257 int SocketStream::TakeConnection() {
258   std::lock_guard lock(connection_mutex_);
259   return TakeConnectionWithLockHeld();
260 }
261 
TakeConnectionWithLockHeld()262 int SocketStream::TakeConnectionWithLockHeld() {
263   ++connection_own_count_;
264 
265   if (ready_ && (connection_fd_ != kInvalidFd) &&
266       (connection_pipe_r_fd_ == kInvalidFd)) {
267     int fd_list[2];
268     if (pipe(fd_list) >= 0) {
269       connection_pipe_r_fd_ = fd_list[0];
270       connection_pipe_w_fd_ = fd_list[1];
271     }
272   }
273 
274   if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) ||
275       (connection_pipe_w_fd_ == kInvalidFd)) {
276     return kInvalidFd;
277   }
278   return connection_fd_;
279 }
280 
ReleaseConnection()281 void SocketStream::ReleaseConnection() {
282   std::lock_guard lock(connection_mutex_);
283   ReleaseConnectionWithLockHeld();
284 }
285 
ReleaseConnectionWithLockHeld()286 void SocketStream::ReleaseConnectionWithLockHeld() {
287   --connection_own_count_;
288 
289   if (connection_own_count_ <= 0) {
290     ready_ = false;
291     if (connection_fd_ != kInvalidFd) {
292       close(connection_fd_);
293       connection_fd_ = kInvalidFd;
294     }
295     if (connection_pipe_r_fd_ != kInvalidFd) {
296       close(connection_pipe_r_fd_);
297       connection_pipe_r_fd_ = kInvalidFd;
298     }
299     if (connection_pipe_w_fd_ != kInvalidFd) {
300       close(connection_pipe_w_fd_);
301       connection_pipe_w_fd_ = kInvalidFd;
302     }
303   }
304 }
305 
306 // Listen for connections on the given port.
307 // If port is 0, a random unused port is chosen and can be retrieved with
308 // port().
Listen(uint16_t port)309 Status ServerSocket::Listen(uint16_t port) {
310   int socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
311   if (socket_fd == kInvalidFd) {
312     return Status::Unknown();
313   }
314 
315   // Allow binding to an address that may still be in use by a closed socket.
316   constexpr int value = 1;
317   setsockopt(socket_fd,
318              SOL_SOCKET,
319              SO_REUSEADDR,
320              reinterpret_cast<const char*>(&value),
321              sizeof(int));
322 
323   if (port != 0) {
324     struct sockaddr_in6 addr = {};
325     socklen_t addr_len = sizeof(addr);
326     addr.sin6_family = AF_INET6;
327     addr.sin6_port = htons(port);
328     addr.sin6_addr = in6addr_any;
329     if (bind(socket_fd, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
330       close(socket_fd);
331       return Status::Unknown();
332     }
333   }
334 
335   if (listen(socket_fd, kServerBacklogLength) < 0) {
336     close(socket_fd);
337     return Status::Unknown();
338   }
339 
340   // Find out which port the socket is listening on, and fill in port_.
341   struct sockaddr_in6 addr = {};
342   socklen_t addr_len = sizeof(addr);
343   if (getsockname(socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
344           0 ||
345       static_cast<size_t>(addr_len) > sizeof(addr)) {
346     close(socket_fd);
347     return Status::Unknown();
348   }
349 
350   port_ = ntohs(addr.sin6_port);
351 
352   // Mark as ready and take ownership of the socket by this object.
353   {
354     std::lock_guard lock(socket_mutex_);
355     socket_fd_ = socket_fd;
356     TakeSocketWithLockHeld();
357     ready_ = true;
358   }
359 
360   return OkStatus();
361 }
362 
363 // Accept a connection. Blocks until after a client is connected.
364 // On success, returns a SocketStream connected to the new client.
Accept()365 Result<SocketStream> ServerSocket::Accept() {
366   struct sockaddr_in6 sockaddr_client_ = {};
367   socklen_t len = sizeof(sockaddr_client_);
368 
369   SocketOwnership ownership(this);
370   if (ownership.fd() == kInvalidFd) {
371     return Status::Unknown();
372   }
373 
374   // Wait for a connection or a tear down notification.
375   pollfd fds_to_poll[2];
376   fds_to_poll[0].fd = ownership.fd();
377   fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
378   fds_to_poll[1].fd = ownership.pipe_r_fd();
379   fds_to_poll[1].events = POLLIN;
380   int rv = poll(fds_to_poll, 2, -1);
381   if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) {
382     return Status::Unknown();
383   }
384 
385   int connection_fd = accept(
386       ownership.fd(), reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
387   if (connection_fd == kInvalidFd) {
388     return Status::Unknown();
389   }
390   ConfigureSocket(connection_fd);
391 
392   return SocketStream(connection_fd);
393 }
394 
395 // Close the server socket, preventing further connections.
Close()396 void ServerSocket::Close() {
397   SocketOwnership ownership(this);
398   {
399     std::lock_guard lock(socket_mutex_);
400     if (ready_) {
401       // Shutdown the socket and send tear down notification to unblock any
402       // waiters.
403       if (socket_fd_ != kInvalidFd) {
404         shutdown(socket_fd_, SHUT_RDWR);
405       }
406       if (socket_pipe_w_fd_ != kInvalidFd) {
407         write(socket_pipe_w_fd_, "T", 1);
408       }
409 
410       // Release ownership of the socket by this object and mark as no longer
411       // ready.
412       ReleaseSocketWithLockHeld();
413       ready_ = false;
414     }
415   }
416 }
417 
TakeSocket()418 int ServerSocket::TakeSocket() {
419   std::lock_guard lock(socket_mutex_);
420   return TakeSocketWithLockHeld();
421 }
422 
TakeSocketWithLockHeld()423 int ServerSocket::TakeSocketWithLockHeld() {
424   ++socket_own_count_;
425 
426   if (ready_ && (socket_fd_ != kInvalidFd) &&
427       (socket_pipe_r_fd_ == kInvalidFd)) {
428     int fd_list[2];
429     if (pipe(fd_list) >= 0) {
430       socket_pipe_r_fd_ = fd_list[0];
431       socket_pipe_w_fd_ = fd_list[1];
432     }
433   }
434 
435   if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) ||
436       (socket_pipe_w_fd_ == kInvalidFd)) {
437     return kInvalidFd;
438   }
439   return socket_fd_;
440 }
441 
ReleaseSocket()442 void ServerSocket::ReleaseSocket() {
443   std::lock_guard lock(socket_mutex_);
444   ReleaseSocketWithLockHeld();
445 }
446 
ReleaseSocketWithLockHeld()447 void ServerSocket::ReleaseSocketWithLockHeld() {
448   --socket_own_count_;
449 
450   if (socket_own_count_ <= 0) {
451     ready_ = false;
452     if (socket_fd_ != kInvalidFd) {
453       close(socket_fd_);
454       socket_fd_ = kInvalidFd;
455     }
456     if (socket_pipe_r_fd_ != kInvalidFd) {
457       close(socket_pipe_r_fd_);
458       socket_pipe_r_fd_ = kInvalidFd;
459     }
460     if (socket_pipe_w_fd_ != kInvalidFd) {
461       close(socket_pipe_w_fd_);
462       socket_pipe_w_fd_ = kInvalidFd;
463     }
464   }
465 }
466 
467 }  // namespace pw::stream
468