1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/unix_domain_client_socket_posix.h"
6
7 #include <unistd.h>
8
9 #include <memory>
10 #include <utility>
11
12 #include "base/files/file_path.h"
13 #include "base/files/scoped_temp_dir.h"
14 #include "base/functional/bind.h"
15 #include "base/posix/eintr_wrapper.h"
16 #include "build/build_config.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/sockaddr_storage.h"
20 #include "net/base/sockaddr_util_posix.h"
21 #include "net/base/test_completion_callback.h"
22 #include "net/socket/socket_posix.h"
23 #include "net/socket/unix_domain_server_socket_posix.h"
24 #include "net/test/gtest_util.h"
25 #include "net/test/test_with_task_environment.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "testing/gmock/include/gmock/gmock.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29
30 using net::test::IsError;
31 using net::test::IsOk;
32
33 namespace net {
34 namespace {
35
// Name of the socket file that the fixture places inside its unique
// temporary directory.
constexpr char kSocketFilename[] = "socket_for_testing";
37
UserCanConnectCallback(bool allow_user,const UnixDomainServerSocket::Credentials & credentials)38 bool UserCanConnectCallback(
39 bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
40 // Here peers are running in same process.
41 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
42 EXPECT_EQ(getpid(), credentials.process_id);
43 #endif
44 EXPECT_EQ(getuid(), credentials.user_id);
45 EXPECT_EQ(getgid(), credentials.group_id);
46 return allow_user;
47 }
48
CreateAuthCallback(bool allow_user)49 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
50 return base::BindRepeating(&UserCanConnectCallback, allow_user);
51 }
52
53 // Connects socket synchronously.
ConnectSynchronously(StreamSocket * socket)54 int ConnectSynchronously(StreamSocket* socket) {
55 TestCompletionCallback connect_callback;
56 int rv = socket->Connect(connect_callback.callback());
57 if (rv == ERR_IO_PENDING)
58 rv = connect_callback.WaitForResult();
59 return rv;
60 }
61
62 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
63 // Returns length of data read, or a net error.
ReadSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len,int min_data_len)64 int ReadSynchronously(StreamSocket* socket,
65 IOBuffer* buf,
66 int buf_len,
67 int min_data_len) {
68 DCHECK_LE(min_data_len, buf_len);
69 scoped_refptr<DrainableIOBuffer> read_buf =
70 base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
71 TestCompletionCallback read_callback;
72 // Iterate reading several times (but not infinite) until it reads at least
73 // |min_data_len| bytes into |buf|.
74 for (int retry_count = 10;
75 retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
76 // Try at least once when min_data_len == 0.
77 min_data_len == 0);
78 --retry_count) {
79 int rv = socket->Read(
80 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
81 EXPECT_GE(read_buf->BytesRemaining(), rv);
82 if (rv == ERR_IO_PENDING) {
83 // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
84 // when some data has been read.
85 if (min_data_len == 0) {
86 // No data has been read because of for-loop condition.
87 DCHECK_EQ(0, read_buf->BytesConsumed());
88 return ERR_IO_PENDING;
89 }
90 rv = read_callback.WaitForResult();
91 }
92 EXPECT_NE(ERR_IO_PENDING, rv);
93 if (rv < 0)
94 return rv;
95 read_buf->DidConsume(rv);
96 }
97 EXPECT_LE(0, read_buf->BytesRemaining());
98 return read_buf->BytesConsumed();
99 }
100
101 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
102 // Returns length of data written, or a net error.
WriteSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len)103 int WriteSynchronously(StreamSocket* socket,
104 IOBuffer* buf,
105 int buf_len) {
106 scoped_refptr<DrainableIOBuffer> write_buf =
107 base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
108 TestCompletionCallback write_callback;
109 // Iterate writing several times (but not infinite) until it writes buf fully.
110 for (int retry_count = 10;
111 retry_count > 0 && write_buf->BytesRemaining() > 0;
112 --retry_count) {
113 int rv =
114 socket->Write(write_buf.get(), write_buf->BytesRemaining(),
115 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
116 EXPECT_GE(write_buf->BytesRemaining(), rv);
117 if (rv == ERR_IO_PENDING)
118 rv = write_callback.WaitForResult();
119 EXPECT_NE(ERR_IO_PENDING, rv);
120 if (rv < 0)
121 return rv;
122 write_buf->DidConsume(rv);
123 }
124 EXPECT_LE(0, write_buf->BytesRemaining());
125 return write_buf->BytesConsumed();
126 }
127
128 class UnixDomainClientSocketTest : public TestWithTaskEnvironment {
129 protected:
UnixDomainClientSocketTest()130 UnixDomainClientSocketTest() {
131 EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
132 socket_path_ = temp_dir_.GetPath().Append(kSocketFilename).value();
133 }
134
135 base::ScopedTempDir temp_dir_;
136 std::string socket_path_;
137 };
138
// Verifies that a client connects to a filesystem-path server socket. The
// server's |accepted_socket| is only populated once the pending accept
// completes.
TEST_F(UnixDomainClientSocketTest, Connect) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  EXPECT_FALSE(accepted_socket);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_FALSE(accepted_socket);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // ASSERT rather than EXPECT: |accepted_socket| is dereferenced below.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
}
164
// Verifies three things:
//  - the server can accept a raw socket descriptor;
//  - the client's connected descriptor can be released;
//  - the released descriptor is still usable after being re-wrapped in a new
//    UnixDomainClientSocket.
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  SocketDescriptor accepted_socket_fd = kInvalidSocket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
                                                 accept_callback.callback()));
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_NE(kInvalidSocket, accepted_socket_fd);

  SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
  EXPECT_NE(kInvalidSocket, client_socket_fd);

  // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
  // to be sure it hasn't gotten accidentally closed. The address is built with
  // the same namespace setting that the socket was connected with.
  SockaddrStorage addr;
  ASSERT_TRUE(FillUnixAddress(socket_path_, kUseAbstractNamespace, &addr));
  auto adopter = std::make_unique<SocketPosix>();
  adopter->AdoptConnectedSocket(client_socket_fd, addr);
  UnixDomainClientSocket rewrapped_socket(std::move(adopter));
  EXPECT_TRUE(rewrapped_socket.IsConnected());

  // Try to read data. The peer (|accepted_socket_fd|) is still open and has
  // written nothing, and the test expects the read to stay pending.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            rewrapped_socket.Read(read_buffer.get(), kReadDataSize,
                                  read_callback.callback()));

  EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
212
TEST_F(UnixDomainClientSocketTest,ConnectWithAbstractNamespace)213 TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
214 const bool kUseAbstractNamespace = true;
215
216 UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
217 EXPECT_FALSE(client_socket.IsConnected());
218
219 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
220 UnixDomainServerSocket server_socket(CreateAuthCallback(true),
221 kUseAbstractNamespace);
222 EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
223
224 std::unique_ptr<StreamSocket> accepted_socket;
225 TestCompletionCallback accept_callback;
226 EXPECT_EQ(ERR_IO_PENDING,
227 server_socket.Accept(&accepted_socket, accept_callback.callback()));
228 EXPECT_FALSE(accepted_socket);
229
230 EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
231 EXPECT_TRUE(client_socket.IsConnected());
232 // Server has not yet beend notified of the connection.
233 EXPECT_FALSE(accepted_socket);
234
235 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
236 EXPECT_TRUE(accepted_socket);
237 EXPECT_TRUE(accepted_socket->IsConnected());
238 #else
239 EXPECT_THAT(ConnectSynchronously(&client_socket),
240 IsError(ERR_ADDRESS_INVALID));
241 #endif
242 }
243
// Connecting to a filesystem path at which no socket has been bound must
// fail with ERR_FILE_NOT_FOUND.
TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
  UnixDomainClientSocket client_socket(socket_path_,
                                       /*use_abstract_namespace=*/false);
  EXPECT_FALSE(client_socket.IsConnected());

  const int connect_result = ConnectSynchronously(&client_socket);
  EXPECT_THAT(connect_result, IsError(ERR_FILE_NOT_FOUND));
}
252
TEST_F(UnixDomainClientSocketTest,ConnectToNonExistentSocketWithAbstractNamespace)253 TEST_F(UnixDomainClientSocketTest,
254 ConnectToNonExistentSocketWithAbstractNamespace) {
255 const bool kUseAbstractNamespace = true;
256
257 UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
258 EXPECT_FALSE(client_socket.IsConnected());
259
260 TestCompletionCallback connect_callback;
261 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
262 EXPECT_THAT(ConnectSynchronously(&client_socket),
263 IsError(ERR_CONNECTION_REFUSED));
264 #else
265 EXPECT_THAT(ConnectSynchronously(&client_socket),
266 IsError(ERR_ADDRESS_INVALID));
267 #endif
268 }
269
// Verifies that disconnecting the client leaves both ends disconnected. It
// also checks that a pending server-side read completes with 0 (EOF).
TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // ASSERT rather than EXPECT: |accepted_socket| is dereferenced below.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            accepted_socket->Read(read_buffer.get(), kReadDataSize,
                                  read_callback.callback()));

  // Disconnect from client side.
  client_socket.Disconnect();
  EXPECT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(accepted_socket->IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
302
// Verifies that disconnecting the server's accepted socket leaves both ends
// disconnected. It also checks that a pending client-side read completes with
// 0 (EOF).
TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // ASSERT rather than EXPECT: |accepted_socket| is dereferenced below.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Try to read data.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            client_socket.Read(read_buffer.get(), kReadDataSize,
                               read_callback.callback()));

  // Disconnect from server side.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());

  // Connection closed by peer.
  EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
  // Note that read callback won't be called when the connection is closed
  // locally before the peer closes it. SocketPosix just clears callbacks.
}
335
// Verifies data transfer in both directions. Data is written first and read
// afterwards; the server's reply is consumed with several smaller reads, and
// a final read with nothing left to receive reports ERR_IO_PENDING.
TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // ASSERT rather than EXPECT: |accepted_socket| is dereferenced below.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Send data from client to server.
  const int kWriteDataSize = 10;
  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  EXPECT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // The buffer is bigger than write data size.
  const int kReadBufferSize = kWriteDataSize * 2;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  EXPECT_EQ(kWriteDataSize,
            ReadSynchronously(accepted_socket.get(), read_buffer.get(),
                              kReadBufferSize, kWriteDataSize));
  EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
            std::string(read_buffer->data(), kWriteDataSize));

  // Send data from server to client.
  EXPECT_EQ(kWriteDataSize,
            WriteSynchronously(accepted_socket.get(), write_buffer.get(),
                               kWriteDataSize));

  // Read multiple times.
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  EXPECT_EQ(kSmallReadBufferSize,
            ReadSynchronously(&client_socket, read_buffer.get(),
                              kSmallReadBufferSize, kSmallReadBufferSize));
  EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
            std::string(read_buffer->data(), kSmallReadBufferSize));

  EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
            ReadSynchronously(&client_socket, read_buffer.get(),
                              kReadBufferSize,
                              kWriteDataSize - kSmallReadBufferSize));
  EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
                        kWriteDataSize - kSmallReadBufferSize),
            std::string(read_buffer->data(),
                        kWriteDataSize - kSmallReadBufferSize));

  // No more data.
  EXPECT_EQ(
      ERR_IO_PENDING,
      ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
404
// Verifies that a server-side read issued before the client writes completes
// once data arrives. The remaining data can then be read and, afterwards,
// a further read reports ERR_IO_PENDING.
TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
  const bool kUseAbstractNamespace = false;
  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> accepted_socket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.Accept(&accepted_socket, accept_callback.callback()));
  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  // ASSERT rather than EXPECT: |accepted_socket| is dereferenced below.
  ASSERT_TRUE(accepted_socket);
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(client_socket.IsConnected());

  // Wait for data from client.
  const int kWriteDataSize = 10;
  const int kReadBufferSize = kWriteDataSize * 2;
  const int kSmallReadBufferSize = kWriteDataSize / 3;
  // Read smaller than write data size first.
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            accepted_socket->Read(read_buffer.get(), kSmallReadBufferSize,
                                  read_callback.callback()));

  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  EXPECT_EQ(
      kWriteDataSize,
      WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));

  // First read completed. ASSERT the range: an error or oversized |rv| would
  // make the remaining size below exceed |kReadBufferSize|, tripping the
  // DCHECK in ReadSynchronously().
  int rv = read_callback.WaitForResult();
  ASSERT_LT(0, rv);
  ASSERT_LE(rv, kSmallReadBufferSize);

  // Read remaining data.
  const int expected_remaining_data_size = kWriteDataSize - rv;
  EXPECT_LE(0, expected_remaining_data_size);
  EXPECT_EQ(expected_remaining_data_size,
            ReadSynchronously(accepted_socket.get(), read_buffer.get(),
                              kReadBufferSize, expected_remaining_data_size));
  // No more data.
  EXPECT_EQ(ERR_IO_PENDING,
            ReadSynchronously(accepted_socket.get(), read_buffer.get(),
                              kReadBufferSize, 0));

  // Disconnect from server side after read-write.
  accepted_socket->Disconnect();
  EXPECT_FALSE(accepted_socket->IsConnected());
  EXPECT_FALSE(client_socket.IsConnected());
}
460
461 } // namespace
462 } // namespace net
463