xref: /aosp_15_r20/external/cronet/net/socket/unix_domain_client_socket_posix_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/unix_domain_client_socket_posix.h"
6 
7 #include <unistd.h>
8 
9 #include <memory>
10 #include <utility>
11 
12 #include "base/files/file_path.h"
13 #include "base/files/scoped_temp_dir.h"
14 #include "base/functional/bind.h"
15 #include "base/posix/eintr_wrapper.h"
16 #include "build/build_config.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "net/base/sockaddr_storage.h"
20 #include "net/base/sockaddr_util_posix.h"
21 #include "net/base/test_completion_callback.h"
22 #include "net/socket/socket_posix.h"
23 #include "net/socket/unix_domain_server_socket_posix.h"
24 #include "net/test/gtest_util.h"
25 #include "net/test/test_with_task_environment.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "testing/gmock/include/gmock/gmock.h"
28 #include "testing/gtest/include/gtest/gtest.h"
29 
30 using net::test::IsError;
31 using net::test::IsOk;
32 
33 namespace net {
34 namespace {
35 
// Name of the socket file created inside each test's temporary directory.
constexpr char kSocketFilename[] = "socket_for_testing";
37 
UserCanConnectCallback(bool allow_user,const UnixDomainServerSocket::Credentials & credentials)38 bool UserCanConnectCallback(
39     bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
40   // Here peers are running in same process.
41 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
42   EXPECT_EQ(getpid(), credentials.process_id);
43 #endif
44   EXPECT_EQ(getuid(), credentials.user_id);
45   EXPECT_EQ(getgid(), credentials.group_id);
46   return allow_user;
47 }
48 
CreateAuthCallback(bool allow_user)49 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
50   return base::BindRepeating(&UserCanConnectCallback, allow_user);
51 }
52 
53 // Connects socket synchronously.
ConnectSynchronously(StreamSocket * socket)54 int ConnectSynchronously(StreamSocket* socket) {
55   TestCompletionCallback connect_callback;
56   int rv = socket->Connect(connect_callback.callback());
57   if (rv == ERR_IO_PENDING)
58     rv = connect_callback.WaitForResult();
59   return rv;
60 }
61 
62 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
63 // Returns length of data read, or a net error.
ReadSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len,int min_data_len)64 int ReadSynchronously(StreamSocket* socket,
65                       IOBuffer* buf,
66                       int buf_len,
67                       int min_data_len) {
68   DCHECK_LE(min_data_len, buf_len);
69   scoped_refptr<DrainableIOBuffer> read_buf =
70       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
71   TestCompletionCallback read_callback;
72   // Iterate reading several times (but not infinite) until it reads at least
73   // |min_data_len| bytes into |buf|.
74   for (int retry_count = 10;
75        retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
76                            // Try at least once when min_data_len == 0.
77                            min_data_len == 0);
78        --retry_count) {
79     int rv = socket->Read(
80         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
81     EXPECT_GE(read_buf->BytesRemaining(), rv);
82     if (rv == ERR_IO_PENDING) {
83       // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
84       // when some data has been read.
85       if (min_data_len == 0) {
86         // No data has been read because of for-loop condition.
87         DCHECK_EQ(0, read_buf->BytesConsumed());
88         return ERR_IO_PENDING;
89       }
90       rv = read_callback.WaitForResult();
91     }
92     EXPECT_NE(ERR_IO_PENDING, rv);
93     if (rv < 0)
94       return rv;
95     read_buf->DidConsume(rv);
96   }
97   EXPECT_LE(0, read_buf->BytesRemaining());
98   return read_buf->BytesConsumed();
99 }
100 
101 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
102 // Returns length of data written, or a net error.
WriteSynchronously(StreamSocket * socket,IOBuffer * buf,int buf_len)103 int WriteSynchronously(StreamSocket* socket,
104                        IOBuffer* buf,
105                        int buf_len) {
106   scoped_refptr<DrainableIOBuffer> write_buf =
107       base::MakeRefCounted<DrainableIOBuffer>(buf, buf_len);
108   TestCompletionCallback write_callback;
109   // Iterate writing several times (but not infinite) until it writes buf fully.
110   for (int retry_count = 10;
111        retry_count > 0 && write_buf->BytesRemaining() > 0;
112        --retry_count) {
113     int rv =
114         socket->Write(write_buf.get(), write_buf->BytesRemaining(),
115                       write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
116     EXPECT_GE(write_buf->BytesRemaining(), rv);
117     if (rv == ERR_IO_PENDING)
118       rv = write_callback.WaitForResult();
119     EXPECT_NE(ERR_IO_PENDING, rv);
120     if (rv < 0)
121       return rv;
122     write_buf->DidConsume(rv);
123   }
124   EXPECT_LE(0, write_buf->BytesRemaining());
125   return write_buf->BytesConsumed();
126 }
127 
128 class UnixDomainClientSocketTest : public TestWithTaskEnvironment {
129  protected:
UnixDomainClientSocketTest()130   UnixDomainClientSocketTest() {
131     EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
132     socket_path_ = temp_dir_.GetPath().Append(kSocketFilename).value();
133   }
134 
135   base::ScopedTempDir temp_dir_;
136   std::string socket_path_;
137 };
138 
// A client connecting to a listening filesystem socket succeeds, and the
// server's pending Accept() then completes with a connected socket.
TEST_F(UnixDomainClientSocketTest, Connect) {
  constexpr bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server(CreateAuthCallback(true),
                                kUseAbstractNamespace);
  EXPECT_THAT(server.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  std::unique_ptr<StreamSocket> server_side;
  TestCompletionCallback accept_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server.Accept(&server_side, accept_done.callback()));
  EXPECT_FALSE(server_side);

  UnixDomainClientSocket client(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client), IsOk());
  EXPECT_TRUE(client.IsConnected());
  // The accept completion has not been delivered to the server yet.
  EXPECT_FALSE(server_side);

  EXPECT_THAT(accept_done.WaitForResult(), IsOk());
  EXPECT_TRUE(server_side);
  EXPECT_TRUE(server_side->IsConnected());
}
164 
// Exercises AcceptSocketDescriptor() on the server and
// ReleaseConnectedSocket() on the client. The released client descriptor,
// re-wrapped in a new UnixDomainClientSocket, must still be connected and able
// to post a read.
TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
  const bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server_socket(CreateAuthCallback(true),
                                       kUseAbstractNamespace);
  EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());

  SocketDescriptor accepted_socket_fd = kInvalidSocket;
  TestCompletionCallback accept_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
                                                 accept_callback.callback()));
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client_socket.IsConnected());

  EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
  EXPECT_TRUE(client_socket.IsConnected());
  // Server has not yet been notified of the connection.
  EXPECT_EQ(kInvalidSocket, accepted_socket_fd);

  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_NE(kInvalidSocket, accepted_socket_fd);

  SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
  EXPECT_NE(kInvalidSocket, client_socket_fd);

  // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
  // to be sure it hasn't gotten accidentally closed. The peer address is built
  // with the same namespace setting the client connected with, rather than a
  // hard-coded value that could silently diverge from it.
  SockaddrStorage addr;
  ASSERT_TRUE(FillUnixAddress(socket_path_, kUseAbstractNamespace, &addr));
  auto adopter = std::make_unique<SocketPosix>();
  adopter->AdoptConnectedSocket(client_socket_fd, addr);
  UnixDomainClientSocket rewrapped_socket(std::move(adopter));
  EXPECT_TRUE(rewrapped_socket.IsConnected());

  // Try to read data. Nothing has been written, so the read must pend.
  const int kReadDataSize = 10;
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            rewrapped_socket.Read(
                read_buffer.get(), kReadDataSize, read_callback.callback()));

  // |accepted_socket_fd| is a raw descriptor, so the test closes it itself.
  EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
}
212 
TEST_F(UnixDomainClientSocketTest,ConnectWithAbstractNamespace)213 TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
214   const bool kUseAbstractNamespace = true;
215 
216   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
217   EXPECT_FALSE(client_socket.IsConnected());
218 
219 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
220   UnixDomainServerSocket server_socket(CreateAuthCallback(true),
221                                        kUseAbstractNamespace);
222   EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
223 
224   std::unique_ptr<StreamSocket> accepted_socket;
225   TestCompletionCallback accept_callback;
226   EXPECT_EQ(ERR_IO_PENDING,
227             server_socket.Accept(&accepted_socket, accept_callback.callback()));
228   EXPECT_FALSE(accepted_socket);
229 
230   EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
231   EXPECT_TRUE(client_socket.IsConnected());
232   // Server has not yet beend notified of the connection.
233   EXPECT_FALSE(accepted_socket);
234 
235   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
236   EXPECT_TRUE(accepted_socket);
237   EXPECT_TRUE(accepted_socket->IsConnected());
238 #else
239   EXPECT_THAT(ConnectSynchronously(&client_socket),
240               IsError(ERR_ADDRESS_INVALID));
241 #endif
242 }
243 
// Connecting to a filesystem socket path that nothing is bound to fails with
// ERR_FILE_NOT_FOUND.
TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
  constexpr bool kUseAbstractNamespace = false;

  UnixDomainClientSocket client(socket_path_, kUseAbstractNamespace);
  EXPECT_FALSE(client.IsConnected());

  const int connect_result = ConnectSynchronously(&client);
  EXPECT_THAT(connect_result, IsError(ERR_FILE_NOT_FOUND));
}
252 
TEST_F(UnixDomainClientSocketTest,ConnectToNonExistentSocketWithAbstractNamespace)253 TEST_F(UnixDomainClientSocketTest,
254        ConnectToNonExistentSocketWithAbstractNamespace) {
255   const bool kUseAbstractNamespace = true;
256 
257   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
258   EXPECT_FALSE(client_socket.IsConnected());
259 
260   TestCompletionCallback connect_callback;
261 #if BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
262   EXPECT_THAT(ConnectSynchronously(&client_socket),
263               IsError(ERR_CONNECTION_REFUSED));
264 #else
265   EXPECT_THAT(ConnectSynchronously(&client_socket),
266               IsError(ERR_ADDRESS_INVALID));
267 #endif
268 }
269 
// Disconnecting the client end makes both ends report disconnected and wakes a
// pending read on the accepted (server-side) socket with EOF.
TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
  constexpr bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server(CreateAuthCallback(true),
                                kUseAbstractNamespace);
  EXPECT_THAT(server.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> server_side;
  TestCompletionCallback accept_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server.Accept(&server_side, accept_done.callback()));
  UnixDomainClientSocket client(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client), IsOk());

  EXPECT_THAT(accept_done.WaitForResult(), IsOk());
  EXPECT_TRUE(server_side->IsConnected());
  EXPECT_TRUE(client.IsConnected());

  // Post a read on the server side; nothing was written, so it pends.
  constexpr int kReadDataSize = 10;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server_side->Read(buffer.get(), kReadDataSize,
                              read_done.callback()));

  // Close from the client end; both ends then report disconnected.
  client.Disconnect();
  EXPECT_FALSE(client.IsConnected());
  EXPECT_FALSE(server_side->IsConnected());

  // The peer's close surfaces as EOF (0) on the pending server-side read.
  EXPECT_EQ(0, read_done.WaitForResult());
  // A read on the side that closes locally would not complete at all:
  // SocketPosix just clears its callbacks in that case.
}
302 
// Disconnecting the accepted (server-side) socket makes both ends report
// disconnected and wakes a pending client read with EOF.
TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
  constexpr bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server(CreateAuthCallback(true),
                                kUseAbstractNamespace);
  EXPECT_THAT(server.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> server_side;
  TestCompletionCallback accept_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server.Accept(&server_side, accept_done.callback()));
  UnixDomainClientSocket client(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client), IsOk());

  EXPECT_THAT(accept_done.WaitForResult(), IsOk());
  EXPECT_TRUE(server_side->IsConnected());
  EXPECT_TRUE(client.IsConnected());

  // Post a read on the client; nothing was written, so it pends.
  constexpr int kReadDataSize = 10;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kReadDataSize);
  TestCompletionCallback read_done;
  EXPECT_EQ(ERR_IO_PENDING,
            client.Read(buffer.get(), kReadDataSize, read_done.callback()));

  // Close from the server end; both ends then report disconnected.
  server_side->Disconnect();
  EXPECT_FALSE(server_side->IsConnected());
  EXPECT_FALSE(client.IsConnected());

  // The peer's close surfaces as EOF (0) on the pending client read.
  EXPECT_EQ(0, read_done.WaitForResult());
  // A read on the side that closes locally would not complete at all:
  // SocketPosix just clears its callbacks in that case.
}
335 
TEST_F(UnixDomainClientSocketTest,ReadAfterWrite)336 TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
337   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
338   EXPECT_THAT(server_socket.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
339   std::unique_ptr<StreamSocket> accepted_socket;
340   TestCompletionCallback accept_callback;
341   EXPECT_EQ(ERR_IO_PENDING,
342             server_socket.Accept(&accepted_socket, accept_callback.callback()));
343   UnixDomainClientSocket client_socket(socket_path_, false);
344   EXPECT_THAT(ConnectSynchronously(&client_socket), IsOk());
345 
346   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
347   EXPECT_TRUE(accepted_socket->IsConnected());
348   EXPECT_TRUE(client_socket.IsConnected());
349 
350   // Send data from client to server.
351   const int kWriteDataSize = 10;
352   auto write_buffer =
353       base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
354   EXPECT_EQ(
355       kWriteDataSize,
356       WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
357 
358   // The buffer is bigger than write data size.
359   const int kReadBufferSize = kWriteDataSize * 2;
360   auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
361   EXPECT_EQ(kWriteDataSize,
362             ReadSynchronously(accepted_socket.get(),
363                               read_buffer.get(),
364                               kReadBufferSize,
365                               kWriteDataSize));
366   EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
367             std::string(read_buffer->data(), kWriteDataSize));
368 
369   // Send data from server and client.
370   EXPECT_EQ(kWriteDataSize,
371             WriteSynchronously(
372                 accepted_socket.get(), write_buffer.get(), kWriteDataSize));
373 
374   // Read multiple times.
375   const int kSmallReadBufferSize = kWriteDataSize / 3;
376   EXPECT_EQ(kSmallReadBufferSize,
377             ReadSynchronously(&client_socket,
378                               read_buffer.get(),
379                               kSmallReadBufferSize,
380                               kSmallReadBufferSize));
381   EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
382             std::string(read_buffer->data(), kSmallReadBufferSize));
383 
384   EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
385             ReadSynchronously(&client_socket,
386                               read_buffer.get(),
387                               kReadBufferSize,
388                               kWriteDataSize - kSmallReadBufferSize));
389   EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
390                         kWriteDataSize - kSmallReadBufferSize),
391             std::string(read_buffer->data(),
392                         kWriteDataSize - kSmallReadBufferSize));
393 
394   // No more data.
395   EXPECT_EQ(
396       ERR_IO_PENDING,
397       ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));
398 
399   // Disconnect from server side after read-write.
400   accepted_socket->Disconnect();
401   EXPECT_FALSE(accepted_socket->IsConnected());
402   EXPECT_FALSE(client_socket.IsConnected());
403 }
404 
// A read posted on the server before any data exists completes once the
// client writes. The rest of the payload can then be drained, after which no
// further data is pending.
TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
  constexpr bool kUseAbstractNamespace = false;

  UnixDomainServerSocket server(CreateAuthCallback(true),
                                kUseAbstractNamespace);
  EXPECT_THAT(server.BindAndListen(socket_path_, /*backlog=*/1), IsOk());
  std::unique_ptr<StreamSocket> server_side;
  TestCompletionCallback accept_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server.Accept(&server_side, accept_done.callback()));
  UnixDomainClientSocket client(socket_path_, kUseAbstractNamespace);
  EXPECT_THAT(ConnectSynchronously(&client), IsOk());

  EXPECT_THAT(accept_done.WaitForResult(), IsOk());
  EXPECT_TRUE(server_side->IsConnected());
  EXPECT_TRUE(client.IsConnected());

  constexpr int kWriteDataSize = 10;
  constexpr int kReadBufferSize = 2 * kWriteDataSize;
  constexpr int kSmallReadBufferSize = kWriteDataSize / 3;

  // Post a server read smaller than the upcoming write. Nothing has been sent
  // yet, so it must pend.
  auto read_buffer = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  TestCompletionCallback first_read_done;
  EXPECT_EQ(ERR_IO_PENDING,
            server_side->Read(read_buffer.get(), kSmallReadBufferSize,
                              first_read_done.callback()));

  auto write_buffer =
      base::MakeRefCounted<StringIOBuffer>(std::string(kWriteDataSize, 'd'));
  EXPECT_EQ(kWriteDataSize,
            WriteSynchronously(&client, write_buffer.get(), kWriteDataSize));

  // The early read now completes with 1..kSmallReadBufferSize bytes.
  const int first_read_size = first_read_done.WaitForResult();
  EXPECT_LT(0, first_read_size);
  EXPECT_LE(first_read_size, kSmallReadBufferSize);

  // Drain whatever the first read left behind.
  const int remaining = kWriteDataSize - first_read_size;
  EXPECT_LE(0, remaining);
  EXPECT_EQ(remaining, ReadSynchronously(server_side.get(), read_buffer.get(),
                                         kReadBufferSize, remaining));
  // Nothing else is buffered.
  EXPECT_EQ(ERR_IO_PENDING,
            ReadSynchronously(server_side.get(), read_buffer.get(),
                              kReadBufferSize, 0));

  // Disconnecting the server side leaves the client disconnected too.
  server_side->Disconnect();
  EXPECT_FALSE(server_side->IsConnected());
  EXPECT_FALSE(client.IsConnected());
}
460 
461 }  // namespace
462 }  // namespace net
463