1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_server_socket.h"
6
7 #include <memory>
8 #include <string>
9 #include <vector>
10
11 #include "base/compiler_specific.h"
12 #include "base/memory/ref_counted.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/base/ip_address.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/net_errors.h"
18 #include "net/base/test_completion_callback.h"
19 #include "net/log/net_log_source.h"
20 #include "net/socket/tcp_client_socket.h"
21 #include "net/test/gtest_util.h"
22 #include "net/test/test_with_task_environment.h"
23 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
24 #include "testing/gmock/include/gmock/gmock.h"
25 #include "testing/gtest/include/gtest/gtest.h"
26 #include "testing/platform_test.h"
27
28 using net::test::IsError;
29 using net::test::IsOk;
30
31 namespace net {
32
33 namespace {
// Backlog length handed to TCPServerSocket::Listen() by every fixture setup
// helper in this file.
constexpr int kListenBacklog = 5;
35
36 class TCPServerSocketTest : public PlatformTest, public WithTaskEnvironment {
37 protected:
TCPServerSocketTest()38 TCPServerSocketTest() : socket_(nullptr, NetLogSource()) {}
39
SetUpIPv4()40 void SetUpIPv4() {
41 IPEndPoint address(IPAddress::IPv4Localhost(), 0);
42 ASSERT_THAT(
43 socket_.Listen(address, kListenBacklog, /*ipv6_only=*/std::nullopt),
44 IsOk());
45 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
46 }
47
SetUpIPv6(bool * success)48 void SetUpIPv6(bool* success) {
49 *success = false;
50 IPEndPoint address(IPAddress::IPv6Localhost(), 0);
51 if (socket_.Listen(address, kListenBacklog, /*ipv6_only=*/std::nullopt) !=
52 0) {
53 LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
54 "disabled. Skipping the test";
55 return;
56 }
57 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
58 *success = true;
59 }
60
SetUpIPv6AllInterfaces(bool ipv6_only)61 void SetUpIPv6AllInterfaces(bool ipv6_only) {
62 IPEndPoint address(IPAddress::IPv6AllZeros(), 0);
63 ASSERT_THAT(socket_.Listen(address, kListenBacklog, ipv6_only), IsOk());
64 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
65 }
66
GetPeerAddress(StreamSocket * socket)67 static IPEndPoint GetPeerAddress(StreamSocket* socket) {
68 IPEndPoint address;
69 EXPECT_THAT(socket->GetPeerAddress(&address), IsOk());
70 return address;
71 }
72
local_address_list() const73 AddressList local_address_list() const {
74 return AddressList(local_address_);
75 }
76
77 TCPServerSocket socket_;
78 IPEndPoint local_address_;
79 };
80
// Start a client connect, then Accept() it; the accept may finish either
// synchronously or through its callback.
TEST_F(TCPServerSocketTest, Accept) {
  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());

  TestCompletionCallback connect_cb;
  TCPClientSocket client(local_address_list(), nullptr, nullptr, nullptr,
                         NetLogSource());
  const int connect_rv = client.Connect(connect_cb.callback());

  TestCompletionCallback accept_cb;
  std::unique_ptr<StreamSocket> server_side;
  IPEndPoint accepted_peer;
  const int accept_rv =
      socket_.Accept(&server_side, accept_cb.callback(), &accepted_peer);
  ASSERT_THAT(accept_cb.GetResult(accept_rv), IsOk());
  ASSERT_NE(server_side.get(), nullptr);

  // Accept() must report the client's address; over loopback it is ours.
  EXPECT_EQ(accepted_peer.address(), local_address_.address());

  // The accepted socket must agree about who is on the other end.
  EXPECT_EQ(GetPeerAddress(server_side.get()).address(),
            local_address_.address());

  EXPECT_THAT(connect_cb.GetResult(connect_rv), IsOk());
}
108
109 // Test Accept() callback.
TEST_F(TCPServerSocketTest,AcceptAsync)110 TEST_F(TCPServerSocketTest, AcceptAsync) {
111 ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
112
113 TestCompletionCallback accept_callback;
114 std::unique_ptr<StreamSocket> accepted_socket;
115 IPEndPoint peer_address;
116
117 ASSERT_THAT(socket_.Accept(&accepted_socket, accept_callback.callback(),
118 &peer_address),
119 IsError(ERR_IO_PENDING));
120
121 TestCompletionCallback connect_callback;
122 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
123 nullptr, NetLogSource());
124 int connect_result = connecting_socket.Connect(connect_callback.callback());
125 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
126
127 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
128
129 EXPECT_TRUE(accepted_socket != nullptr);
130
131 // |peer_address| should be correctly populated.
132 EXPECT_EQ(peer_address.address(), local_address_.address());
133
134 // Both sockets should be on the loopback network interface.
135 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
136 local_address_.address());
137 }
138
139 // Test Accept() when client disconnects right after trying to connect.
TEST_F(TCPServerSocketTest,AcceptClientDisconnectAfterConnect)140 TEST_F(TCPServerSocketTest, AcceptClientDisconnectAfterConnect) {
141 ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
142
143 TestCompletionCallback accept_callback;
144 std::unique_ptr<StreamSocket> accepted_socket;
145 IPEndPoint peer_address;
146
147 TestCompletionCallback connect_callback;
148 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
149 nullptr, NetLogSource());
150 int connect_result = connecting_socket.Connect(connect_callback.callback());
151 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
152
153 int accept_result = socket_.Accept(&accepted_socket,
154 accept_callback.callback(), &peer_address);
155 connecting_socket.Disconnect();
156
157 EXPECT_THAT(accept_callback.GetResult(accept_result), IsOk());
158
159 EXPECT_TRUE(accepted_socket != nullptr);
160
161 // |peer_address| should be correctly populated.
162 EXPECT_EQ(peer_address.address(), local_address_.address());
163 }
164
165 // Accept two connections simultaneously.
TEST_F(TCPServerSocketTest,Accept2Connections)166 TEST_F(TCPServerSocketTest, Accept2Connections) {
167 ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
168
169 TestCompletionCallback accept_callback;
170 std::unique_ptr<StreamSocket> accepted_socket;
171 IPEndPoint peer_address;
172
173 ASSERT_EQ(ERR_IO_PENDING,
174 socket_.Accept(&accepted_socket, accept_callback.callback(),
175 &peer_address));
176
177 TestCompletionCallback connect_callback;
178 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
179 nullptr, NetLogSource());
180 int connect_result = connecting_socket.Connect(connect_callback.callback());
181
182 TestCompletionCallback connect_callback2;
183 TCPClientSocket connecting_socket2(local_address_list(), nullptr, nullptr,
184 nullptr, NetLogSource());
185 int connect_result2 =
186 connecting_socket2.Connect(connect_callback2.callback());
187
188 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
189
190 TestCompletionCallback accept_callback2;
191 std::unique_ptr<StreamSocket> accepted_socket2;
192 IPEndPoint peer_address2;
193 int result = socket_.Accept(&accepted_socket2, accept_callback2.callback(),
194 &peer_address2);
195 result = accept_callback2.GetResult(result);
196 ASSERT_THAT(result, IsOk());
197
198 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
199 EXPECT_THAT(connect_callback2.GetResult(connect_result2), IsOk());
200
201 EXPECT_TRUE(accepted_socket != nullptr);
202 EXPECT_TRUE(accepted_socket2 != nullptr);
203 EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
204
205 EXPECT_EQ(peer_address.address(), local_address_.address());
206 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
207 local_address_.address());
208 EXPECT_EQ(peer_address2.address(), local_address_.address());
209 EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
210 local_address_.address());
211 }
212
// Same as Accept, but over the IPv6 loopback address. Reported as skipped
// (not passed) when the host cannot listen on ::1.
TEST_F(TCPServerSocketTest, AcceptIPv6) {
  bool initialized = false;
  ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
  if (!initialized) {
    GTEST_SKIP() << "Could not listen on ::1; IPv6 is probably disabled.";
  }

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
                                    nullptr, NetLogSource());
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  std::unique_ptr<StreamSocket> accepted_socket;
  IPEndPoint peer_address;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
                              &peer_address);
  ASSERT_THAT(accept_callback.GetResult(result), IsOk());

  ASSERT_NE(accepted_socket.get(), nullptr);

  // |peer_address| should be correctly populated.
  EXPECT_EQ(peer_address.address(), local_address_.address());

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
}
243
244 class TCPServerSocketTestWithIPv6Only
245 : public TCPServerSocketTest,
246 public testing::WithParamInterface<bool> {
247 public:
AttemptToConnect(const IPAddress & dest_addr,bool should_succeed)248 void AttemptToConnect(const IPAddress& dest_addr, bool should_succeed) {
249 TCPClientSocket connecting_socket(
250 AddressList(IPEndPoint(dest_addr, local_address_.port())), nullptr,
251 nullptr, nullptr, NetLogSource());
252
253 TestCompletionCallback connect_cb;
254 int connect_result = connecting_socket.Connect(connect_cb.callback());
255 if (!should_succeed) {
256 connect_result = connect_cb.GetResult(connect_result);
257 ASSERT_EQ(connect_result, net::ERR_CONNECTION_REFUSED);
258 return;
259 }
260
261 std::unique_ptr<StreamSocket> accepted_socket;
262 IPEndPoint peer_address;
263
264 TestCompletionCallback accept_cb;
265 int accept_result =
266 socket_.Accept(&accepted_socket, accept_cb.callback(), &peer_address);
267 ASSERT_EQ(accept_cb.GetResult(accept_result), net::OK);
268 ASSERT_EQ(connect_cb.GetResult(connect_result), net::OK);
269
270 // |accepted_socket| should be available.
271 ASSERT_NE(accepted_socket.get(), nullptr);
272
273 // |peer_address| should be correctly populated.
274 if (peer_address.address().IsIPv4MappedIPv6()) {
275 ASSERT_EQ(ConvertIPv4MappedIPv6ToIPv4(peer_address.address()), dest_addr);
276 } else {
277 ASSERT_EQ(peer_address.address(), dest_addr);
278 }
279 }
280 };
281
// Listening on the IPv6 wildcard address accepts ::1 in every case, and
// accepts 127.0.0.1 only when |ipv6_only| is false.
TEST_P(TCPServerSocketTestWithIPv6Only, AcceptIPv6Only) {
  const bool ipv6_only = GetParam();
  ASSERT_NO_FATAL_FAILURE(SetUpIPv6AllInterfaces(ipv6_only));
  ASSERT_FALSE(local_address_list().empty());

  // 127.0.0.1 succeeds when |ipv6_only| is false and vice versa.
  // AttemptToConnect() uses fatal assertions, so its failures must be
  // propagated explicitly. Otherwise the test would go on, and an IPv4
  // connection that was unexpectedly established could still be queued on
  // |socket_| and be picked up by the Accept() of the ::1 attempt.
  ASSERT_NO_FATAL_FAILURE(AttemptToConnect(IPAddress::IPv4Localhost(),
                                           /*should_succeed=*/!ipv6_only));

  // ::1 succeeds regardless of |ipv6_only|.
  ASSERT_NO_FATAL_FAILURE(
      AttemptToConnect(IPAddress::IPv6Localhost(), /*should_succeed=*/true));
}
293
294 INSTANTIATE_TEST_SUITE_P(All, TCPServerSocketTestWithIPv6Only, testing::Bool());
295
// Accept a connection, write a message from the accepted (server-side)
// socket, and check that the client reads back exactly the same bytes.
TEST_F(TCPServerSocketTest, AcceptIO) {
  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
                                    nullptr, NetLogSource());
  int connect_result = connecting_socket.Connect(connect_callback.callback());

  TestCompletionCallback accept_callback;
  std::unique_ptr<StreamSocket> accepted_socket;
  IPEndPoint peer_address;
  int result = socket_.Accept(&accepted_socket, accept_callback.callback(),
                              &peer_address);
  ASSERT_THAT(accept_callback.GetResult(result), IsOk());

  ASSERT_NE(accepted_socket.get(), nullptr);

  // |peer_address| should be correctly populated.
  EXPECT_EQ(peer_address.address(), local_address_.address());

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
            local_address_.address());

  // The I/O below needs an established connection.
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());

  const std::string message("test message");
  std::vector<char> buffer(message.size());

  // Send the whole message, resuming after partial writes with only the
  // bytes that have not been written yet.
  size_t bytes_written = 0;
  while (bytes_written < message.size()) {
    const size_t remaining = message.size() - bytes_written;
    scoped_refptr<IOBufferWithSize> write_buffer =
        base::MakeRefCounted<IOBufferWithSize>(remaining);
    // Copy just the unsent tail: it is exactly |remaining| bytes long, which
    // is the size of |write_buffer|.
    memmove(write_buffer->data(), message.data() + bytes_written, remaining);

    TestCompletionCallback write_callback;
    int write_result = accepted_socket->Write(
        write_buffer.get(), write_buffer->size(), write_callback.callback(),
        TRAFFIC_ANNOTATION_FOR_TESTS);
    write_result = write_callback.GetResult(write_result);
    // A zero-byte write would make no progress and loop forever.
    ASSERT_GT(write_result, 0);
    ASSERT_LE(bytes_written + write_result, message.size());
    bytes_written += write_result;
  }

  // Read until the whole message has arrived on the client side.
  size_t bytes_read = 0;
  while (bytes_read < message.size()) {
    scoped_refptr<IOBufferWithSize> read_buffer =
        base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_read);
    TestCompletionCallback read_callback;
    int read_result = connecting_socket.Read(
        read_buffer.get(), read_buffer->size(), read_callback.callback());
    read_result = read_callback.GetResult(read_result);
    // 0 means the connection was closed before the full message arrived;
    // fail instead of looping forever.
    ASSERT_GT(read_result, 0);
    ASSERT_LE(bytes_read + read_result, message.size());
    memmove(&buffer[bytes_read], read_buffer->data(),
            static_cast<size_t>(read_result));
    bytes_read += read_result;
  }

  std::string received_message(buffer.begin(), buffer.end());
  ASSERT_EQ(message, received_message);
}
358
359 } // namespace
360
361 } // namespace net
362