xref: /aosp_15_r20/external/cronet/net/socket/tcp_server_socket_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/tcp_server_socket.h"
6 
7 #include <memory>
8 #include <string>
9 #include <vector>
10 
11 #include "base/compiler_specific.h"
12 #include "base/memory/ref_counted.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/base/ip_address.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/net_errors.h"
18 #include "net/base/test_completion_callback.h"
19 #include "net/log/net_log_source.h"
20 #include "net/socket/tcp_client_socket.h"
21 #include "net/test/gtest_util.h"
22 #include "net/test/test_with_task_environment.h"
23 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
24 #include "testing/gmock/include/gmock/gmock.h"
25 #include "testing/gtest/include/gtest/gtest.h"
26 #include "testing/platform_test.h"
27 
28 using net::test::IsError;
29 using net::test::IsOk;
30 
31 namespace net {
32 
33 namespace {
// Backlog value passed to TCPServerSocket::Listen() by every fixture setup
// helper in this file.
constexpr int kListenBacklog = 5;
35 
36 class TCPServerSocketTest : public PlatformTest, public WithTaskEnvironment {
37  protected:
TCPServerSocketTest()38   TCPServerSocketTest() : socket_(nullptr, NetLogSource()) {}
39 
SetUpIPv4()40   void SetUpIPv4() {
41     IPEndPoint address(IPAddress::IPv4Localhost(), 0);
42     ASSERT_THAT(
43         socket_.Listen(address, kListenBacklog, /*ipv6_only=*/std::nullopt),
44         IsOk());
45     ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
46   }
47 
SetUpIPv6(bool * success)48   void SetUpIPv6(bool* success) {
49     *success = false;
50     IPEndPoint address(IPAddress::IPv6Localhost(), 0);
51     if (socket_.Listen(address, kListenBacklog, /*ipv6_only=*/std::nullopt) !=
52         0) {
53       LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
54           "disabled. Skipping the test";
55       return;
56     }
57     ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
58     *success = true;
59   }
60 
SetUpIPv6AllInterfaces(bool ipv6_only)61   void SetUpIPv6AllInterfaces(bool ipv6_only) {
62     IPEndPoint address(IPAddress::IPv6AllZeros(), 0);
63     ASSERT_THAT(socket_.Listen(address, kListenBacklog, ipv6_only), IsOk());
64     ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
65   }
66 
GetPeerAddress(StreamSocket * socket)67   static IPEndPoint GetPeerAddress(StreamSocket* socket) {
68     IPEndPoint address;
69     EXPECT_THAT(socket->GetPeerAddress(&address), IsOk());
70     return address;
71   }
72 
local_address_list() const73   AddressList local_address_list() const {
74     return AddressList(local_address_);
75   }
76 
77   TCPServerSocket socket_;
78   IPEndPoint local_address_;
79 };
80 
// Accept() a connection whose Connect() was started before Accept() is called.
TEST_F(TCPServerSocketTest, Accept) {
  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());

  TestCompletionCallback connect_cb;
  TCPClientSocket client(local_address_list(), nullptr, nullptr, nullptr,
                         NetLogSource());
  const int connect_rv = client.Connect(connect_cb.callback());

  TestCompletionCallback accept_cb;
  std::unique_ptr<StreamSocket> server_side;
  IPEndPoint peer;
  const int accept_rv =
      socket_.Accept(&server_side, accept_cb.callback(), &peer);
  ASSERT_THAT(accept_cb.GetResult(accept_rv), IsOk());
  ASSERT_NE(server_side.get(), nullptr);

  // The Accept() out-param and the accepted socket must both report the
  // loopback client, i.e. the same IP the server listens on.
  EXPECT_EQ(peer.address(), local_address_.address());
  EXPECT_EQ(GetPeerAddress(server_side.get()).address(),
            local_address_.address());

  EXPECT_THAT(connect_cb.GetResult(connect_rv), IsOk());
}
108 
109 // Test Accept() callback.
TEST_F(TCPServerSocketTest,AcceptAsync)110 TEST_F(TCPServerSocketTest, AcceptAsync) {
111   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
112 
113   TestCompletionCallback accept_callback;
114   std::unique_ptr<StreamSocket> accepted_socket;
115   IPEndPoint peer_address;
116 
117   ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback(),
118                              &peer_address),
119               IsError(ERR_IO_PENDING));
120 
121   TestCompletionCallback connect_callback;
122   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
123                                     nullptr, NetLogSource());
124   int connect_result = connecting_socket.Connect(connect_callback.callback());
125   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
126 
127   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
128 
129   EXPECT_TRUE(accepted_socket != nullptr);
130 
131   // |peer_address| should be correctly populated.
132   EXPECT_EQ(peer_address.address(), local_address_.address());
133 
134   // Both sockets should be on the loopback network interface.
135   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
136             local_address_.address());
137 }
138 
139 // Test Accept() when client disconnects right after trying to connect.
TEST_F(TCPServerSocketTest,AcceptClientDisconnectAfterConnect)140 TEST_F(TCPServerSocketTest, AcceptClientDisconnectAfterConnect) {
141   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
142 
143   TestCompletionCallback accept_callback;
144   std::unique_ptr<StreamSocket> accepted_socket;
145   IPEndPoint peer_address;
146 
147   TestCompletionCallback connect_callback;
148   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
149                                     nullptr, NetLogSource());
150   int connect_result = connecting_socket.Connect(connect_callback.callback());
151   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
152 
153   int accept_result = socket_.Accept(&accepted_socket,
154                                      accept_callback.callback(), &peer_address);
155   connecting_socket.Disconnect();
156 
157   EXPECT_THAT(accept_callback.GetResult(accept_result), IsOk());
158 
159   EXPECT_TRUE(accepted_socket != nullptr);
160 
161   // |peer_address| should be correctly populated.
162   EXPECT_EQ(peer_address.address(), local_address_.address());
163 }
164 
165 // Accept two connections simultaneously.
TEST_F(TCPServerSocketTest,Accept2Connections)166 TEST_F(TCPServerSocketTest, Accept2Connections) {
167   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
168 
169   TestCompletionCallback accept_callback;
170   std::unique_ptr<StreamSocket> accepted_socket;
171   IPEndPoint peer_address;
172 
173   ASSERT_EQ(ERR_IO_PENDING,
174             socket_.Accept(&accepted_socket, accept_callback.callback(),
175                            &peer_address));
176 
177   TestCompletionCallback connect_callback;
178   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
179                                     nullptr, NetLogSource());
180   int connect_result = connecting_socket.Connect(connect_callback.callback());
181 
182   TestCompletionCallback connect_callback2;
183   TCPClientSocket connecting_socket2(local_address_list(), nullptr, nullptr,
184                                      nullptr, NetLogSource());
185   int connect_result2 =
186       connecting_socket2.Connect(connect_callback2.callback());
187 
188   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
189 
190   TestCompletionCallback accept_callback2;
191   std::unique_ptr<StreamSocket> accepted_socket2;
192   IPEndPoint peer_address2;
193   int result = socket_.Accept(&accepted_socket2, accept_callback2.callback(),
194                               &peer_address2);
195   result = accept_callback2.GetResult(result);
196   ASSERT_THAT(result, IsOk());
197 
198   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
199   EXPECT_THAT(connect_callback2.GetResult(connect_result2), IsOk());
200 
201   EXPECT_TRUE(accepted_socket != nullptr);
202   EXPECT_TRUE(accepted_socket2 != nullptr);
203   EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
204 
205   EXPECT_EQ(peer_address.address(), local_address_.address());
206   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
207             local_address_.address());
208   EXPECT_EQ(peer_address2.address(), local_address_.address());
209   EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
210             local_address_.address());
211 }
212 
// Same as Accept, but over IPv6 loopback (::1). Skipped when the host cannot
// listen on ::1.
TEST_F(TCPServerSocketTest, AcceptIPv6) {
  bool initialized = false;
  ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
  if (!initialized) {
    // Report as skipped rather than silently passing without coverage.
    GTEST_SKIP() << "Unable to listen on ::1";
  }

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
                                    nullptr, NetLogSource());
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  std::unique_ptr<StreamSocket> accepted_socket;
  IPEndPoint peer_address;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
                              &peer_address);
  result = accept_callback.GetResult(result);
  ASSERT_THAT(result, IsOk());

  ASSERT_NE(accepted_socket.get(), nullptr);

  // |peer_address| should be correctly populated.
  EXPECT_EQ(peer_address.address(), local_address_.address());

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
}
243 
244 class TCPServerSocketTestWithIPv6Only
245     : public TCPServerSocketTest,
246       public testing::WithParamInterface<bool> {
247  public:
AttemptToConnect(const IPAddress & dest_addr,bool should_succeed)248   void AttemptToConnect(const IPAddress& dest_addr, bool should_succeed) {
249     TCPClientSocket connecting_socket(
250         AddressList(IPEndPoint(dest_addr, local_address_.port())), nullptr,
251         nullptr, nullptr, NetLogSource());
252 
253     TestCompletionCallback connect_cb;
254     int connect_result = connecting_socket.Connect(connect_cb.callback());
255     if (!should_succeed) {
256       connect_result = connect_cb.GetResult(connect_result);
257       ASSERT_EQ(connect_result, net::ERR_CONNECTION_REFUSED);
258       return;
259     }
260 
261     std::unique_ptr<StreamSocket> accepted_socket;
262     IPEndPoint peer_address;
263 
264     TestCompletionCallback accept_cb;
265     int accept_result =
266         socket_.Accept(&accepted_socket, accept_cb.callback(), &peer_address);
267     ASSERT_EQ(accept_cb.GetResult(accept_result), net::OK);
268     ASSERT_EQ(connect_cb.GetResult(connect_result), net::OK);
269 
270     // |accepted_socket| should be available.
271     ASSERT_NE(accepted_socket.get(), nullptr);
272 
273     // |peer_address| should be correctly populated.
274     if (peer_address.address().IsIPv4MappedIPv6()) {
275       ASSERT_EQ(ConvertIPv4MappedIPv6ToIPv4(peer_address.address()), dest_addr);
276     } else {
277       ASSERT_EQ(peer_address.address(), dest_addr);
278     }
279   }
280 };
281 
// With the server listening on the IPv6 all-zeros address, checks which
// loopback destinations are reachable depending on |ipv6_only|.
TEST_P(TCPServerSocketTestWithIPv6Only, AcceptIPv6Only) {
  const bool ipv6_only = GetParam();
  ASSERT_NO_FATAL_FAILURE(SetUpIPv6AllInterfaces(ipv6_only));
  ASSERT_FALSE(local_address_list().empty());

  // AttemptToConnect() uses ASSERT_*, which only returns from the helper;
  // ASSERT_NO_FATAL_FAILURE propagates such a failure and stops the test.

  // 127.0.0.1 succeeds when |ipv6_only| is false and vice versa.
  ASSERT_NO_FATAL_FAILURE(AttemptToConnect(IPAddress::IPv4Localhost(),
                                           /*should_succeed=*/!ipv6_only));

  // ::1 succeeds regardless of |ipv6_only|.
  ASSERT_NO_FATAL_FAILURE(AttemptToConnect(IPAddress::IPv6Localhost(),
                                           /*should_succeed=*/true));
}
293 
294 INSTANTIATE_TEST_SUITE_P(All, TCPServerSocketTestWithIPv6Only, testing::Bool());
295 
TEST_F(TCPServerSocketTest,AcceptIO)296 TEST_F(TCPServerSocketTest, AcceptIO) {
297   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
298 
299   TestCompletionCallback connect_callback;
300   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
301                                     nullptr, NetLogSource());
302   int connect_result = connecting_socket.Connect(connect_callback.callback());
303 
304   TestCompletionCallback accept_callback;
305   std::unique_ptr<StreamSocket> accepted_socket;
306   IPEndPoint peer_address;
307   int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
308                               &peer_address);
309   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
310 
311   ASSERT_TRUE(accepted_socket.get() != nullptr);
312 
313   // |peer_address| should be correctly populated.
314   EXPECT_EQ(peer_address.address(), local_address_.address());
315 
316   // Both sockets should be on the loopback network interface.
317   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
318             local_address_.address());
319 
320   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
321 
322   const std::string message("test message");
323   std::vector<char> buffer(message.size());
324 
325   size_t bytes_written = 0;
326   while (bytes_written < message.size()) {
327     scoped_refptr<IOBufferWithSize> write_buffer =
328         base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_written);
329     memmove(write_buffer->data(), message.data(), message.size());
330 
331     TestCompletionCallback write_callback;
332     int write_result = accepted_socket->Write(
333         write_buffer.get(), write_buffer->size(), write_callback.callback(),
334         TRAFFIC_ANNOTATION_FOR_TESTS);
335     write_result = write_callback.GetResult(write_result);
336     ASSERT_TRUE(write_result >= 0);
337     ASSERT_TRUE(bytes_written + write_result <= message.size());
338     bytes_written += write_result;
339   }
340 
341   size_t bytes_read = 0;
342   while (bytes_read < message.size()) {
343     scoped_refptr<IOBufferWithSize> read_buffer =
344         base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_read);
345     TestCompletionCallback read_callback;
346     int read_result = connecting_socket.Read(
347         read_buffer.get(), read_buffer->size(), read_callback.callback());
348     read_result = read_callback.GetResult(read_result);
349     ASSERT_TRUE(read_result >= 0);
350     ASSERT_TRUE(bytes_read + read_result <= message.size());
351     memmove(&buffer[bytes_read], read_buffer->data(), read_result);
352     bytes_read += read_result;
353   }
354 
355   std::string received_message(buffer.begin(), buffer.end());
356   ASSERT_EQ(message, received_message);
357 }
358 
359 }  // namespace
360 
361 }  // namespace net
362