1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_stream/socket_stream.h"
16
17 #if defined(_WIN32) && _WIN32
18 #include <fcntl.h>
19 #include <io.h>
20 #include <winsock2.h>
21 #include <ws2tcpip.h>
22 #define SHUT_RDWR SD_BOTH
23 #else
24 #include <arpa/inet.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <poll.h>
28 #include <sys/socket.h>
29 #include <sys/types.h>
30 #include <unistd.h>
31 #endif // defined(_WIN32) && _WIN32
32
33 #include <cerrno>
34 #include <cstring>
35
36 #include "pw_assert/check.h"
37 #include "pw_log/log.h"
38 #include "pw_status/status.h"
39 #include "pw_string/to_string.h"
40
41 namespace pw::stream {
42 namespace {
43
// Maximum number of pending connections queued by listen() in
// ServerSocket::Listen(). Typed as int to match listen()'s `backlog` parameter.
constexpr int kServerBacklogLength = 1;
// Host used by SocketStream::Connect() when the caller passes nullptr.
constexpr const char* kLocalhostAddress = "localhost";
46
// Applies the per-socket options this file relies on to `socket`.
//
// On macOS this enables SO_NOSIGPIPE, so writing to a connection the peer has
// dropped does not raise SIGPIPE. Linux gets the same effect per call through
// MSG_NOSIGNAL in DoWrite(). A failure to set the option is logged as a
// warning and otherwise ignored. On every other platform this is a no-op.
void ConfigureSocket([[maybe_unused]] int socket) {
#if defined(__APPLE__)
  constexpr int kEnable = 1;
  const int rc =
      setsockopt(socket, SOL_SOCKET, SO_NOSIGPIPE, &kEnable, sizeof(kEnable));
  if (rc < 0) {
    PW_LOG_WARN("Failed to set SO_NOSIGPIPE: %s", std::strerror(errno));
  }
#endif  // defined(__APPLE__)
}
58
#if defined(_WIN32) && _WIN32
// Windows shims: map the POSIX-style calls used in this file onto their
// Winsock / CRT counterparts so the socket code below can be shared between
// platforms.

// Closes a socket handle via Winsock's closesocket().
// NOTE(review): unqualified `close(...)` elsewhere in this file also resolves
// to this overload for the pipe descriptors created by pipe() below (see the
// Release*WithLockHeld() functions). Those are CRT file descriptors, not
// sockets — confirm they are actually released on Windows.
int close(SOCKET s) { return closesocket(s); }

// Forwards to the CRT's _write().
ssize_t write(int fd, const void* buf, size_t count) {
  return _write(fd, buf, count);
}

// Forwards to WSAPoll().
// NOTE(review): WSAPoll() is documented as socket-only. The pipe read
// descriptor polled next to the socket in DoRead()/Accept() is a CRT file
// descriptor — verify that the teardown notification actually wakes waiters.
int poll(struct pollfd* fds, unsigned int nfds, int timeout) {
  return WSAPoll(fds, nfds, timeout);
}

// Creates a binary-mode anonymous pipe with a 256-byte buffer via _pipe().
int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); }

// Adapts the POSIX-style signature used in this file (`const void*` value,
// `unsigned int` length) to Winsock's setsockopt(), which takes a
// `const char*` value and an `int` length. Casting `fd` to SOCKET makes the
// inner call select the Winsock overload instead of this function.
int setsockopt(
    int fd, int level, int optname, const void* optval, unsigned int optlen) {
  return setsockopt(static_cast<SOCKET>(fd),
                    level,
                    optname,
                    static_cast<const char*>(optval),
                    static_cast<int>(optlen));
}

// On construction, calls WSAStartup() requesting Winsock 2.2 and fails a
// PW_CHECK if it does not return 0. On destruction, calls WSACleanup().
class WinsockInitializer {
 public:
  WinsockInitializer() {
    WSADATA data = {};
    PW_CHECK_INT_EQ(
        WSAStartup(MAKEWORD(2, 2), &data), 0, "Failed to initialize winsock");
  }
  ~WinsockInitializer() {
    // TODO: b/301545011 - This currently fails, probably a cleanup race.
    WSACleanup();
  }
};

// File-scope instance. Winsock is therefore started during this translation
// unit's static initialization and cleaned up during static destruction.
[[maybe_unused]] WinsockInitializer initializer;

#endif  // defined(_WIN32) && _WIN32
97
98 } // namespace
99
Connect(const char * host,uint16_t port)100 Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) {
101 if (host == nullptr) {
102 host = kLocalhostAddress;
103 }
104
105 struct addrinfo hints = {};
106 struct addrinfo* res;
107 char port_buffer[6];
108 PW_CHECK(ToString(port, port_buffer).ok());
109 hints.ai_family = AF_UNSPEC;
110 hints.ai_socktype = SOCK_STREAM;
111 hints.ai_flags = AI_NUMERICSERV;
112 if (getaddrinfo(host, port_buffer, &hints, &res) != 0) {
113 PW_LOG_ERROR("Failed to configure connection address for socket");
114 return Status::InvalidArgument();
115 }
116
117 struct addrinfo* rp;
118 int connection_fd;
119 for (rp = res; rp != nullptr; rp = rp->ai_next) {
120 connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
121 if (connection_fd != kInvalidFd) {
122 break;
123 }
124 }
125
126 if (connection_fd == kInvalidFd) {
127 PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno));
128 freeaddrinfo(res);
129 return Status::Unknown();
130 }
131
132 ConfigureSocket(connection_fd);
133 if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) {
134 close(connection_fd);
135 PW_LOG_ERROR(
136 "Failed to connect to %s:%d: %s", host, port, std::strerror(errno));
137 freeaddrinfo(res);
138 return Status::Unknown();
139 }
140
141 // Mark as ready and take ownership of the connection by this object.
142 {
143 std::lock_guard lock(connection_mutex_);
144 connection_fd_ = connection_fd;
145 TakeConnectionWithLockHeld();
146 ready_ = true;
147 }
148
149 freeaddrinfo(res);
150 return OkStatus();
151 }
152
153 // Configures socket options.
SetSockOpt(int level,int optname,const void * optval,unsigned int optlen)154 int SocketStream::SetSockOpt(int level,
155 int optname,
156 const void* optval,
157 unsigned int optlen) {
158 ConnectionOwnership ownership(this);
159 if (ownership.fd() == kInvalidFd) {
160 return EBADF;
161 }
162 return setsockopt(ownership.fd(), level, optname, optval, optlen);
163 }
164
Close()165 void SocketStream::Close() {
166 ConnectionOwnership ownership(this);
167 {
168 std::lock_guard lock(connection_mutex_);
169 if (ready_) {
170 // Shutdown the connection and send tear down notification to unblock any
171 // waiters.
172 if (connection_fd_ != kInvalidFd) {
173 shutdown(connection_fd_, SHUT_RDWR);
174 }
175 if (connection_pipe_w_fd_ != kInvalidFd) {
176 write(connection_pipe_w_fd_, "T", 1);
177 }
178
179 // Release ownership of the connection by this object and mark as no
180 // longer ready.
181 ReleaseConnectionWithLockHeld();
182 ready_ = false;
183 }
184 }
185 }
186
DoWrite(span<const std::byte> data)187 Status SocketStream::DoWrite(span<const std::byte> data) {
188 int send_flags = 0;
189 #if defined(__linux__)
190 // Use MSG_NOSIGNAL to avoid getting a SIGPIPE signal when the remote
191 // peer drops the connection. This is supported on Linux only.
192 send_flags |= MSG_NOSIGNAL;
193 #endif // defined(__linux__)
194
195 ssize_t bytes_sent;
196 {
197 ConnectionOwnership ownership(this);
198 if (ownership.fd() == kInvalidFd) {
199 return Status::Unknown();
200 }
201 bytes_sent = send(ownership.fd(),
202 reinterpret_cast<const char*>(data.data()),
203 data.size_bytes(),
204 send_flags);
205 }
206
207 if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) {
208 if (errno == EPIPE) {
209 // An EPIPE indicates that the connection is closed. Return an OutOfRange
210 // error.
211 return Status::OutOfRange();
212 }
213
214 return Status::Unknown();
215 }
216 return OkStatus();
217 }
218
DoRead(ByteSpan dest)219 StatusWithSize SocketStream::DoRead(ByteSpan dest) {
220 ConnectionOwnership ownership(this);
221 if (ownership.fd() == kInvalidFd) {
222 return StatusWithSize::Unknown();
223 }
224
225 // Wait for data to read or a tear down notification.
226 pollfd fds_to_poll[2];
227 fds_to_poll[0].fd = ownership.fd();
228 fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
229 fds_to_poll[1].fd = ownership.pipe_r_fd();
230 fds_to_poll[1].events = POLLIN;
231 poll(fds_to_poll, 2, -1);
232 if (!(fds_to_poll[0].revents & POLLIN)) {
233 return StatusWithSize::Unknown();
234 }
235
236 ssize_t bytes_rcvd = recv(ownership.fd(),
237 reinterpret_cast<char*>(dest.data()),
238 dest.size_bytes(),
239 0);
240 if (bytes_rcvd == 0) {
241 // Remote peer has closed the connection.
242 Close();
243 return StatusWithSize::OutOfRange();
244 } else if (bytes_rcvd < 0) {
245 if (errno == EAGAIN || errno == EWOULDBLOCK) {
246 // Socket timed out when trying to read.
247 // This should only occur if SO_RCVTIMEO was configured to be nonzero, or
248 // if the socket was opened with the O_NONBLOCK flag to prevent any
249 // blocking when performing reads or writes.
250 return StatusWithSize::ResourceExhausted();
251 }
252 return StatusWithSize::Unknown();
253 }
254 return StatusWithSize(bytes_rcvd);
255 }
256
TakeConnection()257 int SocketStream::TakeConnection() {
258 std::lock_guard lock(connection_mutex_);
259 return TakeConnectionWithLockHeld();
260 }
261
TakeConnectionWithLockHeld()262 int SocketStream::TakeConnectionWithLockHeld() {
263 ++connection_own_count_;
264
265 if (ready_ && (connection_fd_ != kInvalidFd) &&
266 (connection_pipe_r_fd_ == kInvalidFd)) {
267 int fd_list[2];
268 if (pipe(fd_list) >= 0) {
269 connection_pipe_r_fd_ = fd_list[0];
270 connection_pipe_w_fd_ = fd_list[1];
271 }
272 }
273
274 if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) ||
275 (connection_pipe_w_fd_ == kInvalidFd)) {
276 return kInvalidFd;
277 }
278 return connection_fd_;
279 }
280
ReleaseConnection()281 void SocketStream::ReleaseConnection() {
282 std::lock_guard lock(connection_mutex_);
283 ReleaseConnectionWithLockHeld();
284 }
285
ReleaseConnectionWithLockHeld()286 void SocketStream::ReleaseConnectionWithLockHeld() {
287 --connection_own_count_;
288
289 if (connection_own_count_ <= 0) {
290 ready_ = false;
291 if (connection_fd_ != kInvalidFd) {
292 close(connection_fd_);
293 connection_fd_ = kInvalidFd;
294 }
295 if (connection_pipe_r_fd_ != kInvalidFd) {
296 close(connection_pipe_r_fd_);
297 connection_pipe_r_fd_ = kInvalidFd;
298 }
299 if (connection_pipe_w_fd_ != kInvalidFd) {
300 close(connection_pipe_w_fd_);
301 connection_pipe_w_fd_ = kInvalidFd;
302 }
303 }
304 }
305
306 // Listen for connections on the given port.
307 // If port is 0, a random unused port is chosen and can be retrieved with
308 // port().
Listen(uint16_t port)309 Status ServerSocket::Listen(uint16_t port) {
310 int socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
311 if (socket_fd == kInvalidFd) {
312 return Status::Unknown();
313 }
314
315 // Allow binding to an address that may still be in use by a closed socket.
316 constexpr int value = 1;
317 setsockopt(socket_fd,
318 SOL_SOCKET,
319 SO_REUSEADDR,
320 reinterpret_cast<const char*>(&value),
321 sizeof(int));
322
323 if (port != 0) {
324 struct sockaddr_in6 addr = {};
325 socklen_t addr_len = sizeof(addr);
326 addr.sin6_family = AF_INET6;
327 addr.sin6_port = htons(port);
328 addr.sin6_addr = in6addr_any;
329 if (bind(socket_fd, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) {
330 close(socket_fd);
331 return Status::Unknown();
332 }
333 }
334
335 if (listen(socket_fd, kServerBacklogLength) < 0) {
336 close(socket_fd);
337 return Status::Unknown();
338 }
339
340 // Find out which port the socket is listening on, and fill in port_.
341 struct sockaddr_in6 addr = {};
342 socklen_t addr_len = sizeof(addr);
343 if (getsockname(socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) <
344 0 ||
345 static_cast<size_t>(addr_len) > sizeof(addr)) {
346 close(socket_fd);
347 return Status::Unknown();
348 }
349
350 port_ = ntohs(addr.sin6_port);
351
352 // Mark as ready and take ownership of the socket by this object.
353 {
354 std::lock_guard lock(socket_mutex_);
355 socket_fd_ = socket_fd;
356 TakeSocketWithLockHeld();
357 ready_ = true;
358 }
359
360 return OkStatus();
361 }
362
363 // Accept a connection. Blocks until after a client is connected.
364 // On success, returns a SocketStream connected to the new client.
Accept()365 Result<SocketStream> ServerSocket::Accept() {
366 struct sockaddr_in6 sockaddr_client_ = {};
367 socklen_t len = sizeof(sockaddr_client_);
368
369 SocketOwnership ownership(this);
370 if (ownership.fd() == kInvalidFd) {
371 return Status::Unknown();
372 }
373
374 // Wait for a connection or a tear down notification.
375 pollfd fds_to_poll[2];
376 fds_to_poll[0].fd = ownership.fd();
377 fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP;
378 fds_to_poll[1].fd = ownership.pipe_r_fd();
379 fds_to_poll[1].events = POLLIN;
380 int rv = poll(fds_to_poll, 2, -1);
381 if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) {
382 return Status::Unknown();
383 }
384
385 int connection_fd = accept(
386 ownership.fd(), reinterpret_cast<sockaddr*>(&sockaddr_client_), &len);
387 if (connection_fd == kInvalidFd) {
388 return Status::Unknown();
389 }
390 ConfigureSocket(connection_fd);
391
392 return SocketStream(connection_fd);
393 }
394
395 // Close the server socket, preventing further connections.
Close()396 void ServerSocket::Close() {
397 SocketOwnership ownership(this);
398 {
399 std::lock_guard lock(socket_mutex_);
400 if (ready_) {
401 // Shutdown the socket and send tear down notification to unblock any
402 // waiters.
403 if (socket_fd_ != kInvalidFd) {
404 shutdown(socket_fd_, SHUT_RDWR);
405 }
406 if (socket_pipe_w_fd_ != kInvalidFd) {
407 write(socket_pipe_w_fd_, "T", 1);
408 }
409
410 // Release ownership of the socket by this object and mark as no longer
411 // ready.
412 ReleaseSocketWithLockHeld();
413 ready_ = false;
414 }
415 }
416 }
417
TakeSocket()418 int ServerSocket::TakeSocket() {
419 std::lock_guard lock(socket_mutex_);
420 return TakeSocketWithLockHeld();
421 }
422
TakeSocketWithLockHeld()423 int ServerSocket::TakeSocketWithLockHeld() {
424 ++socket_own_count_;
425
426 if (ready_ && (socket_fd_ != kInvalidFd) &&
427 (socket_pipe_r_fd_ == kInvalidFd)) {
428 int fd_list[2];
429 if (pipe(fd_list) >= 0) {
430 socket_pipe_r_fd_ = fd_list[0];
431 socket_pipe_w_fd_ = fd_list[1];
432 }
433 }
434
435 if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) ||
436 (socket_pipe_w_fd_ == kInvalidFd)) {
437 return kInvalidFd;
438 }
439 return socket_fd_;
440 }
441
ReleaseSocket()442 void ServerSocket::ReleaseSocket() {
443 std::lock_guard lock(socket_mutex_);
444 ReleaseSocketWithLockHeld();
445 }
446
ReleaseSocketWithLockHeld()447 void ServerSocket::ReleaseSocketWithLockHeld() {
448 --socket_own_count_;
449
450 if (socket_own_count_ <= 0) {
451 ready_ = false;
452 if (socket_fd_ != kInvalidFd) {
453 close(socket_fd_);
454 socket_fd_ = kInvalidFd;
455 }
456 if (socket_pipe_r_fd_ != kInvalidFd) {
457 close(socket_pipe_r_fd_);
458 socket_pipe_r_fd_ = kInvalidFd;
459 }
460 if (socket_pipe_w_fd_ != kInvalidFd) {
461 close(socket_pipe_w_fd_);
462 socket_pipe_w_fd_ = kInvalidFd;
463 }
464 }
465 }
466
467 } // namespace pw::stream
468