1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_socket.h"
6
7 #include <stddef.h>
8 #include <string.h>
9
10 #include <memory>
11 #include <string>
12 #include <vector>
13
14 #include "base/functional/bind.h"
15 #include "base/memory/ref_counted.h"
16 #include "base/test/bind.h"
17 #include "base/time/time.h"
18 #include "build/build_config.h"
19 #include "net/base/address_list.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/ip_endpoint.h"
22 #include "net/base/net_errors.h"
23 #include "net/base/sockaddr_storage.h"
24 #include "net/base/sys_addrinfo.h"
25 #include "net/base/test_completion_callback.h"
26 #include "net/log/net_log_source.h"
27 #include "net/socket/socket_descriptor.h"
28 #include "net/socket/socket_performance_watcher.h"
29 #include "net/socket/socket_test_util.h"
30 #include "net/socket/tcp_client_socket.h"
31 #include "net/test/embedded_test_server/embedded_test_server.h"
32 #include "net/test/gtest_util.h"
33 #include "net/test/test_with_task_environment.h"
34 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
35 #include "testing/gmock/include/gmock/gmock.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37 #include "testing/platform_test.h"
38
39 #if BUILDFLAG(IS_ANDROID)
40 #include "base/android/build_info.h"
41 #include "net/android/network_change_notifier_factory_android.h"
42 #include "net/base/network_change_notifier.h"
43 #endif // BUILDFLAG(IS_ANDROID)
44
45 // For getsockopt() call.
46 #if BUILDFLAG(IS_WIN)
47 #include <winsock2.h>
48 #else // !BUILDFLAG(IS_WIN)
49 #include <sys/socket.h>
50 #endif // !BUILDFLAG(IS_WIN)
51
52 using net::test::IsError;
53 using net::test::IsOk;
54
55 namespace net {
56
57 namespace {
58
59 // IOBuffer with the ability to invoke a callback when destroyed. Useful for
60 // checking for leaks.
61 class IOBufferWithDestructionCallback : public IOBufferWithSize {
62 public:
IOBufferWithDestructionCallback(base::OnceClosure on_destroy_closure)63 explicit IOBufferWithDestructionCallback(base::OnceClosure on_destroy_closure)
64 : IOBufferWithSize(1024),
65 on_destroy_closure_(std::move(on_destroy_closure)) {
66 DCHECK(on_destroy_closure_);
67 }
68
69 protected:
~IOBufferWithDestructionCallback()70 ~IOBufferWithDestructionCallback() override {
71 std::move(on_destroy_closure_).Run();
72 }
73
74 base::OnceClosure on_destroy_closure_;
75 };
76
77 class TestSocketPerformanceWatcher : public SocketPerformanceWatcher {
78 public:
TestSocketPerformanceWatcher(bool should_notify_updated_rtt)79 explicit TestSocketPerformanceWatcher(bool should_notify_updated_rtt)
80 : should_notify_updated_rtt_(should_notify_updated_rtt) {}
81
82 TestSocketPerformanceWatcher(const TestSocketPerformanceWatcher&) = delete;
83 TestSocketPerformanceWatcher& operator=(const TestSocketPerformanceWatcher&) =
84 delete;
85
86 ~TestSocketPerformanceWatcher() override = default;
87
ShouldNotifyUpdatedRTT() const88 bool ShouldNotifyUpdatedRTT() const override {
89 return should_notify_updated_rtt_;
90 }
91
OnUpdatedRTTAvailable(const base::TimeDelta & rtt)92 void OnUpdatedRTTAvailable(const base::TimeDelta& rtt) override {
93 rtt_notification_count_++;
94 }
95
OnConnectionChanged()96 void OnConnectionChanged() override { connection_changed_count_++; }
97
rtt_notification_count() const98 size_t rtt_notification_count() const { return rtt_notification_count_; }
99
connection_changed_count() const100 size_t connection_changed_count() const { return connection_changed_count_; }
101
102 private:
103 const bool should_notify_updated_rtt_;
104 size_t connection_changed_count_ = 0u;
105 size_t rtt_notification_count_ = 0u;
106 };
107
// Backlog argument passed to Listen() by the listening sockets in these tests.
constexpr int kListenBacklog = 5;
109
110 class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
111 protected:
TCPSocketTest()112 TCPSocketTest() : socket_(nullptr, nullptr, NetLogSource()) {}
113
SetUpListenIPv4()114 void SetUpListenIPv4() {
115 ASSERT_THAT(socket_.Open(ADDRESS_FAMILY_IPV4), IsOk());
116 ASSERT_THAT(socket_.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
117 IsOk());
118 ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
119 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
120 }
121
SetUpListenIPv6(bool * success)122 void SetUpListenIPv6(bool* success) {
123 *success = false;
124
125 if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
126 socket_.Bind(IPEndPoint(IPAddress::IPv6Localhost(), 0)) != OK ||
127 socket_.Listen(kListenBacklog) != OK) {
128 LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
129 "disabled. Skipping the test";
130 return;
131 }
132 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
133 *success = true;
134 }
135
TestAcceptAsync()136 void TestAcceptAsync() {
137 TestCompletionCallback accept_callback;
138 std::unique_ptr<TCPSocket> accepted_socket;
139 IPEndPoint accepted_address;
140 ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
141 accept_callback.callback()),
142 IsError(ERR_IO_PENDING));
143
144 TestCompletionCallback connect_callback;
145 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
146 nullptr, NetLogSource());
147 int connect_result = connecting_socket.Connect(connect_callback.callback());
148 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
149
150 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
151
152 EXPECT_TRUE(accepted_socket.get());
153
154 // Both sockets should be on the loopback network interface.
155 EXPECT_EQ(accepted_address.address(), local_address_.address());
156 }
157
158 #if defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
159 // Tests that notifications to Socket Performance Watcher (SPW) are delivered
160 // correctly. |should_notify_updated_rtt| is true if the SPW is interested in
161 // receiving RTT notifications. |num_messages| is the number of messages that
162 // are written/read by the sockets. |expect_connection_changed_count| is the
163 // expected number of connection change notifications received by the SPW.
164 // |expect_rtt_notification_count| is the expected number of RTT
165 // notifications received by the SPW. This test works by writing
166 // |num_messages| to the socket. A different socket (with a SPW attached to
167 // it) reads the messages.
TestSPWNotifications(bool should_notify_updated_rtt,size_t num_messages,size_t expect_connection_changed_count,size_t expect_rtt_notification_count)168 void TestSPWNotifications(bool should_notify_updated_rtt,
169 size_t num_messages,
170 size_t expect_connection_changed_count,
171 size_t expect_rtt_notification_count) {
172 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
173
174 TestCompletionCallback connect_callback;
175
176 auto watcher = std::make_unique<TestSocketPerformanceWatcher>(
177 should_notify_updated_rtt);
178 TestSocketPerformanceWatcher* watcher_ptr = watcher.get();
179
180 TCPSocket connecting_socket(std::move(watcher), nullptr, NetLogSource());
181
182 int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
183 ASSERT_THAT(result, IsOk());
184 int connect_result =
185 connecting_socket.Connect(local_address_, connect_callback.callback());
186
187 TestCompletionCallback accept_callback;
188 std::unique_ptr<TCPSocket> accepted_socket;
189 IPEndPoint accepted_address;
190 result = socket_.Accept(&accepted_socket, &accepted_address,
191 accept_callback.callback());
192 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
193
194 ASSERT_TRUE(accepted_socket.get());
195
196 // Both sockets should be on the loopback network interface.
197 EXPECT_EQ(accepted_address.address(), local_address_.address());
198
199 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
200
201 for (size_t i = 0; i < num_messages; ++i) {
202 // Use a 1 byte message so that the watcher is notified at most once per
203 // message.
204 const std::string message("t");
205
206 scoped_refptr<IOBufferWithSize> write_buffer =
207 base::MakeRefCounted<IOBufferWithSize>(message.size());
208 memmove(write_buffer->data(), message.data(), message.size());
209
210 TestCompletionCallback write_callback;
211 int write_result = accepted_socket->Write(
212 write_buffer.get(), write_buffer->size(), write_callback.callback(),
213 TRAFFIC_ANNOTATION_FOR_TESTS);
214
215 scoped_refptr<IOBufferWithSize> read_buffer =
216 base::MakeRefCounted<IOBufferWithSize>(message.size());
217 TestCompletionCallback read_callback;
218 int read_result = connecting_socket.Read(
219 read_buffer.get(), read_buffer->size(), read_callback.callback());
220
221 ASSERT_EQ(1, write_callback.GetResult(write_result));
222 ASSERT_EQ(1, read_callback.GetResult(read_result));
223 }
224 EXPECT_EQ(expect_connection_changed_count,
225 watcher_ptr->connection_changed_count());
226 EXPECT_EQ(expect_rtt_notification_count,
227 watcher_ptr->rtt_notification_count());
228 }
229 #endif // defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
230
local_address_list() const231 AddressList local_address_list() const {
232 return AddressList(local_address_);
233 }
234
235 TCPSocket socket_;
236 IPEndPoint local_address_;
237 };
238
239 // Test listening and accepting with a socket bound to an IPv4 address.
TEST_F(TCPSocketTest,Accept)240 TEST_F(TCPSocketTest, Accept) {
241 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
242
243 TestCompletionCallback connect_callback;
244 // TODO(yzshen): Switch to use TCPSocket when it supports client socket
245 // operations.
246 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
247 nullptr, NetLogSource());
248 int connect_result = connecting_socket.Connect(connect_callback.callback());
249
250 TestCompletionCallback accept_callback;
251 std::unique_ptr<TCPSocket> accepted_socket;
252 IPEndPoint accepted_address;
253 int result = socket_.Accept(&accepted_socket, &accepted_address,
254 accept_callback.callback());
255 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
256
257 EXPECT_TRUE(accepted_socket.get());
258
259 // Both sockets should be on the loopback network interface.
260 EXPECT_EQ(accepted_address.address(), local_address_.address());
261
262 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
263 }
264
265 // Test Accept() callback.
TEST_F(TCPSocketTest,AcceptAsync)266 TEST_F(TCPSocketTest, AcceptAsync) {
267 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
268 TestAcceptAsync();
269 }
270
271 // Test AdoptConnectedSocket()
TEST_F(TCPSocketTest,AdoptConnectedSocket)272 TEST_F(TCPSocketTest, AdoptConnectedSocket) {
273 TCPSocket accepting_socket(nullptr, nullptr, NetLogSource());
274 ASSERT_THAT(accepting_socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
275 ASSERT_THAT(accepting_socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
276 IsOk());
277 ASSERT_THAT(accepting_socket.GetLocalAddress(&local_address_), IsOk());
278 ASSERT_THAT(accepting_socket.Listen(kListenBacklog), IsOk());
279
280 TestCompletionCallback connect_callback;
281 // TODO(yzshen): Switch to use TCPSocket when it supports client socket
282 // operations.
283 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
284 nullptr, NetLogSource());
285 int connect_result = connecting_socket.Connect(connect_callback.callback());
286
287 TestCompletionCallback accept_callback;
288 std::unique_ptr<TCPSocket> accepted_socket;
289 IPEndPoint accepted_address;
290 int result = accepting_socket.Accept(&accepted_socket, &accepted_address,
291 accept_callback.callback());
292 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
293
294 SocketDescriptor accepted_descriptor =
295 accepted_socket->ReleaseSocketDescriptorForTesting();
296
297 ASSERT_THAT(
298 socket_.AdoptConnectedSocket(accepted_descriptor, accepted_address),
299 IsOk());
300
301 // socket_ should now have the local address.
302 IPEndPoint adopted_address;
303 ASSERT_THAT(socket_.GetLocalAddress(&adopted_address), IsOk());
304 EXPECT_EQ(local_address_.address(), adopted_address.address());
305
306 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
307 }
308
309 // Test Accept() for AdoptUnconnectedSocket.
TEST_F(TCPSocketTest,AcceptForAdoptedUnconnectedSocket)310 TEST_F(TCPSocketTest, AcceptForAdoptedUnconnectedSocket) {
311 SocketDescriptor existing_socket =
312 CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
313 ASSERT_THAT(socket_.AdoptUnconnectedSocket(existing_socket), IsOk());
314
315 IPEndPoint address(IPAddress::IPv4Localhost(), 0);
316 SockaddrStorage storage;
317 ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len));
318 ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len));
319
320 ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
321 ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
322
323 TestAcceptAsync();
324 }
325
326 // Accept two connections simultaneously.
TEST_F(TCPSocketTest,Accept2Connections)327 TEST_F(TCPSocketTest, Accept2Connections) {
328 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
329
330 TestCompletionCallback accept_callback;
331 std::unique_ptr<TCPSocket> accepted_socket;
332 IPEndPoint accepted_address;
333
334 ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
335 accept_callback.callback()),
336 IsError(ERR_IO_PENDING));
337
338 TestCompletionCallback connect_callback;
339 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
340 nullptr, NetLogSource());
341 int connect_result = connecting_socket.Connect(connect_callback.callback());
342
343 TestCompletionCallback connect_callback2;
344 TCPClientSocket connecting_socket2(local_address_list(), nullptr, nullptr,
345 nullptr, NetLogSource());
346 int connect_result2 =
347 connecting_socket2.Connect(connect_callback2.callback());
348
349 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
350
351 TestCompletionCallback accept_callback2;
352 std::unique_ptr<TCPSocket> accepted_socket2;
353 IPEndPoint accepted_address2;
354
355 int result = socket_.Accept(&accepted_socket2, &accepted_address2,
356 accept_callback2.callback());
357 ASSERT_THAT(accept_callback2.GetResult(result), IsOk());
358
359 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
360 EXPECT_THAT(connect_callback2.GetResult(connect_result2), IsOk());
361
362 EXPECT_TRUE(accepted_socket.get());
363 EXPECT_TRUE(accepted_socket2.get());
364 EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
365
366 EXPECT_EQ(accepted_address.address(), local_address_.address());
367 EXPECT_EQ(accepted_address2.address(), local_address_.address());
368 }
369
370 // Test listening and accepting with a socket bound to an IPv6 address.
TEST_F(TCPSocketTest,AcceptIPv6)371 TEST_F(TCPSocketTest, AcceptIPv6) {
372 bool initialized = false;
373 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized));
374 if (!initialized)
375 return;
376
377 TestCompletionCallback connect_callback;
378 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
379 nullptr, NetLogSource());
380 int connect_result = connecting_socket.Connect(connect_callback.callback());
381
382 TestCompletionCallback accept_callback;
383 std::unique_ptr<TCPSocket> accepted_socket;
384 IPEndPoint accepted_address;
385 int result = socket_.Accept(&accepted_socket, &accepted_address,
386 accept_callback.callback());
387 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
388
389 EXPECT_TRUE(accepted_socket.get());
390
391 // Both sockets should be on the loopback network interface.
392 EXPECT_EQ(accepted_address.address(), local_address_.address());
393
394 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
395 }
396
// Writes a message from the accepted socket and reads it back on the
// connecting socket, handling partial writes and partial reads.
TEST_F(TCPSocketTest, ReadWrite) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  TestCompletionCallback connect_callback;
  TCPSocket connecting_socket(nullptr, nullptr, NetLogSource());
  int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
  ASSERT_THAT(result, IsOk());
  int connect_result =
      connecting_socket.Connect(local_address_, connect_callback.callback());

  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> accepted_socket;
  IPEndPoint accepted_address;
  result = socket_.Accept(&accepted_socket, &accepted_address,
                          accept_callback.callback());
  ASSERT_THAT(accept_callback.GetResult(result), IsOk());

  ASSERT_TRUE(accepted_socket.get());

  // Both sockets should be on the loopback network interface.
  EXPECT_EQ(accepted_address.address(), local_address_.address());

  EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());

  const std::string message("test message");
  std::vector<char> buffer(message.size());

  size_t bytes_written = 0;
  while (bytes_written < message.size()) {
    const size_t remaining = message.size() - bytes_written;
    scoped_refptr<IOBufferWithSize> write_buffer =
        base::MakeRefCounted<IOBufferWithSize>(remaining);
    memcpy(write_buffer->data(), message.data() + bytes_written, remaining);

    TestCompletionCallback write_callback;
    int write_result = accepted_socket->Write(
        write_buffer.get(), write_buffer->size(), write_callback.callback(),
        TRAFFIC_ANNOTATION_FOR_TESTS);
    write_result = write_callback.GetResult(write_result);
    // Require forward progress. A zero-byte write would otherwise make this
    // loop spin forever.
    ASSERT_GT(write_result, 0);
    bytes_written += static_cast<size_t>(write_result);
    ASSERT_LE(bytes_written, message.size());
  }

  size_t bytes_read = 0;
  while (bytes_read < message.size()) {
    scoped_refptr<IOBufferWithSize> read_buffer =
        base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_read);
    TestCompletionCallback read_callback;
    int read_result = connecting_socket.Read(
        read_buffer.get(), read_buffer->size(), read_callback.callback());
    read_result = read_callback.GetResult(read_result);
    // A zero-byte read means end-of-stream. Fail instead of looping forever
    // while waiting for the rest of the message.
    ASSERT_GT(read_result, 0);
    ASSERT_LE(bytes_read + static_cast<size_t>(read_result), message.size());
    memcpy(&buffer[bytes_read], read_buffer->data(), read_result);
    bytes_read += static_cast<size_t>(read_result);
  }

  std::string received_message(buffer.begin(), buffer.end());
  ASSERT_EQ(message, received_message);
}
458
459 // Destroy a TCPSocket while there's a pending read, and make sure the read
460 // IOBuffer that the socket was holding on to is destroyed.
461 // See https://crbug.com/804868.
TEST_F(TCPSocketTest,DestroyWithPendingRead)462 TEST_F(TCPSocketTest, DestroyWithPendingRead) {
463 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
464
465 // Create a connected socket.
466
467 TestCompletionCallback connect_callback;
468 std::unique_ptr<TCPSocket> connecting_socket =
469 std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
470 int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
471 ASSERT_THAT(result, IsOk());
472 int connect_result =
473 connecting_socket->Connect(local_address_, connect_callback.callback());
474
475 TestCompletionCallback accept_callback;
476 std::unique_ptr<TCPSocket> accepted_socket;
477 IPEndPoint accepted_address;
478 result = socket_.Accept(&accepted_socket, &accepted_address,
479 accept_callback.callback());
480 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
481 ASSERT_TRUE(accepted_socket.get());
482 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
483
484 // Try to read from the socket, but never write anything to the other end.
485 base::RunLoop run_loop;
486 scoped_refptr<IOBufferWithDestructionCallback> read_buffer(
487 base::MakeRefCounted<IOBufferWithDestructionCallback>(
488 run_loop.QuitClosure()));
489 TestCompletionCallback read_callback;
490 EXPECT_EQ(ERR_IO_PENDING,
491 connecting_socket->Read(read_buffer.get(), read_buffer->size(),
492 read_callback.callback()));
493
494 // Release the handle to the read buffer and destroy the socket. Make sure the
495 // read buffer is destroyed.
496 read_buffer = nullptr;
497 connecting_socket.reset();
498 run_loop.Run();
499 }
500
501 // Destroy a TCPSocket while there's a pending write, and make sure the write
502 // IOBuffer that the socket was holding on to is destroyed.
TEST_F(TCPSocketTest,DestroyWithPendingWrite)503 TEST_F(TCPSocketTest, DestroyWithPendingWrite) {
504 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
505
506 // Create a connected socket.
507
508 TestCompletionCallback connect_callback;
509 std::unique_ptr<TCPSocket> connecting_socket =
510 std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
511 int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
512 ASSERT_THAT(result, IsOk());
513 int connect_result =
514 connecting_socket->Connect(local_address_, connect_callback.callback());
515
516 TestCompletionCallback accept_callback;
517 std::unique_ptr<TCPSocket> accepted_socket;
518 IPEndPoint accepted_address;
519 result = socket_.Accept(&accepted_socket, &accepted_address,
520 accept_callback.callback());
521 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
522 ASSERT_TRUE(accepted_socket.get());
523 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
524
525 // Repeatedly write to the socket until an operation does not complete
526 // synchronously.
527 base::RunLoop run_loop;
528 scoped_refptr<IOBufferWithDestructionCallback> write_buffer(
529 base::MakeRefCounted<IOBufferWithDestructionCallback>(
530 run_loop.QuitClosure()));
531 memset(write_buffer->data(), '1', write_buffer->size());
532 TestCompletionCallback write_callback;
533 while (true) {
534 result = connecting_socket->Write(write_buffer.get(), write_buffer->size(),
535 write_callback.callback(),
536 TRAFFIC_ANNOTATION_FOR_TESTS);
537 if (result == ERR_IO_PENDING)
538 break;
539 ASSERT_LT(0, result);
540 }
541
542 // Release the handle to the read buffer and destroy the socket. Make sure the
543 // write buffer is destroyed.
544 write_buffer = nullptr;
545 connecting_socket.reset();
546 run_loop.Run();
547 }
548
549 // If a ReadIfReady is pending, it's legal to cancel it and start reading later.
TEST_F(TCPSocketTest,CancelPendingReadIfReady)550 TEST_F(TCPSocketTest, CancelPendingReadIfReady) {
551 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
552
553 // Create a connected socket.
554 TestCompletionCallback connect_callback;
555 std::unique_ptr<TCPSocket> connecting_socket =
556 std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
557 int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
558 ASSERT_THAT(result, IsOk());
559 int connect_result =
560 connecting_socket->Connect(local_address_, connect_callback.callback());
561
562 TestCompletionCallback accept_callback;
563 std::unique_ptr<TCPSocket> accepted_socket;
564 IPEndPoint accepted_address;
565 result = socket_.Accept(&accepted_socket, &accepted_address,
566 accept_callback.callback());
567 ASSERT_THAT(accept_callback.GetResult(result), IsOk());
568 ASSERT_TRUE(accepted_socket.get());
569 ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
570
571 // Try to read from the socket, but never write anything to the other end.
572 base::RunLoop run_loop;
573 scoped_refptr<IOBufferWithDestructionCallback> read_buffer(
574 base::MakeRefCounted<IOBufferWithDestructionCallback>(
575 run_loop.QuitClosure()));
576 TestCompletionCallback read_callback;
577 EXPECT_EQ(ERR_IO_PENDING, connecting_socket->ReadIfReady(
578 read_buffer.get(), read_buffer->size(),
579 read_callback.callback()));
580
581 // Now cancel the pending ReadIfReady().
582 connecting_socket->CancelReadIfReady();
583
584 // Send data to |connecting_socket|.
585 const char kMsg[] = "hello!";
586 scoped_refptr<StringIOBuffer> write_buffer =
587 base::MakeRefCounted<StringIOBuffer>(kMsg);
588
589 TestCompletionCallback write_callback;
590 int write_result = accepted_socket->Write(write_buffer.get(), strlen(kMsg),
591 write_callback.callback(),
592 TRAFFIC_ANNOTATION_FOR_TESTS);
593 const int msg_size = strlen(kMsg);
594 ASSERT_EQ(msg_size, write_result);
595
596 TestCompletionCallback read_callback2;
597 int read_result = connecting_socket->ReadIfReady(
598 read_buffer.get(), read_buffer->size(), read_callback2.callback());
599 if (read_result == ERR_IO_PENDING) {
600 ASSERT_EQ(OK, read_callback2.GetResult(read_result));
601 read_result = connecting_socket->ReadIfReady(
602 read_buffer.get(), read_buffer->size(), read_callback2.callback());
603 }
604
605 ASSERT_EQ(msg_size, read_result);
606 ASSERT_EQ(0, memcmp(&kMsg, read_buffer->data(), msg_size));
607 }
608
// Tracks IsConnected()/IsConnectedAndIdle() through a socket's lifetime:
// before connecting, while connected and idle, with unread data pending, and
// after the peer has closed and all data has been consumed.
TEST_F(TCPSocketTest, IsConnected) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> accepted_socket;
  IPEndPoint accepted_address;
  EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
                             accept_callback.callback()),
              IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
                                    nullptr, NetLogSource());

  // Immediately after creation, the socket should not be connected.
  EXPECT_FALSE(connecting_socket.IsConnected());
  EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());

  int connect_result = connecting_socket.Connect(connect_callback.callback());
  // These checks must be fatal: |accepted_socket| is dereferenced below, and
  // the select() call further down has no timeout.
  ASSERT_THAT(accept_callback.WaitForResult(), IsOk());
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
  ASSERT_TRUE(accepted_socket.get());

  // |connecting_socket| and |accepted_socket| should now both be reported as
  // connected, and idle.
  EXPECT_TRUE(accepted_socket->IsConnected());
  EXPECT_TRUE(accepted_socket->IsConnectedAndIdle());
  EXPECT_TRUE(connecting_socket.IsConnected());
  EXPECT_TRUE(connecting_socket.IsConnectedAndIdle());

  // Write one byte to |accepted_socket|, then close it.
  const char kSomeData[] = "!";
  scoped_refptr<IOBuffer> some_data_buffer =
      base::MakeRefCounted<StringIOBuffer>(kSomeData);
  TestCompletionCallback write_callback;
  EXPECT_EQ(1, write_callback.GetResult(accepted_socket->Write(
                   some_data_buffer.get(), 1, write_callback.callback(),
                   TRAFFIC_ANNOTATION_FOR_TESTS)));
  accepted_socket.reset();

  // Wait until |connecting_socket| is signalled as having data to read.
  fd_set read_fds;
  FD_ZERO(&read_fds);
  SocketDescriptor connecting_fd =
      connecting_socket.SocketDescriptorForTesting();
  FD_SET(connecting_fd, &read_fds);
  ASSERT_EQ(select(FD_SETSIZE, &read_fds, nullptr, nullptr, nullptr), 1);
  ASSERT_TRUE(FD_ISSET(connecting_fd, &read_fds));

  // It should now be reported as connected, but not as idle.
  EXPECT_TRUE(connecting_socket.IsConnected());
  EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());

  // Read the message from |connecting_socket|, then read the end-of-stream.
  scoped_refptr<IOBufferWithSize> read_buffer =
      base::MakeRefCounted<IOBufferWithSize>(2);
  TestCompletionCallback read_callback;
  EXPECT_EQ(1, read_callback.GetResult(connecting_socket.Read(
                   read_buffer.get(), read_buffer->size(),
                   read_callback.callback())));
  EXPECT_EQ(0, read_callback.GetResult(connecting_socket.Read(
                   read_buffer.get(), read_buffer->size(),
                   read_callback.callback())));

  // |connecting_socket| has no more data to read, so it should now be reported
  // as disconnected.
  EXPECT_FALSE(connecting_socket.IsConnected());
  EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());
}
680
681 // Tests that setting a socket option in the BeforeConnectCallback works. With
682 // real sockets, socket options often have to be set before the connect() call,
683 // and the BeforeConnectCallback is the only way to do that, with a
684 // TCPClientSocket.
TEST_F(TCPSocketTest,BeforeConnectCallback)685 TEST_F(TCPSocketTest, BeforeConnectCallback) {
686 // A receive buffer size that is between max and minimum buffer size limits,
687 // and weird enough to likely not be a default value.
688 const int kReceiveBufferSize = 32 * 1024 + 1117;
689 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
690
691 TestCompletionCallback accept_callback;
692 std::unique_ptr<TCPSocket> accepted_socket;
693 IPEndPoint accepted_address;
694 EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
695 accept_callback.callback()),
696 IsError(ERR_IO_PENDING));
697
698 TestCompletionCallback connect_callback;
699 TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
700 nullptr, NetLogSource());
701
702 connecting_socket.SetBeforeConnectCallback(base::BindLambdaForTesting([&] {
703 EXPECT_FALSE(connecting_socket.IsConnected());
704 int result = connecting_socket.SetReceiveBufferSize(kReceiveBufferSize);
705 EXPECT_THAT(result, IsOk());
706 return result;
707 }));
708 int connect_result = connecting_socket.Connect(connect_callback.callback());
709
710 EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
711 EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
712
713 int actual_size = 0;
714 socklen_t actual_size_len = sizeof(actual_size);
715 int os_result = getsockopt(
716 connecting_socket.SocketDescriptorForTesting(), SOL_SOCKET, SO_RCVBUF,
717 reinterpret_cast<char*>(&actual_size), &actual_size_len);
718 ASSERT_EQ(0, os_result);
719 // Linux platforms generally allocate twice as much buffer size is requested to
720 // account for internal kernel data structures.
721 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
722 EXPECT_EQ(2 * kReceiveBufferSize, actual_size);
723 // Unfortunately, Apple platform behavior doesn't seem to be documented, and
724 // doesn't match behavior on any other platforms.
725 // Fuchsia doesn't currently implement SO_RCVBUF.
726 #elif !BUILDFLAG(IS_APPLE) && !BUILDFLAG(IS_FUCHSIA)
727 EXPECT_EQ(kReceiveBufferSize, actual_size);
728 #endif
729 }
730
// An error returned by the BeforeConnectCallback must abort Connect() and be
// reported as the connect result.
TEST_F(TCPSocketTest, BeforeConnectCallbackFails) {
  // A listening server is not strictly required, but having one lets the test
  // check that nothing ever connected to it.
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> server;
  IPEndPoint peer;
  EXPECT_THAT(socket_.Accept(&server, &peer, accept_callback.callback()),
              IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket client(local_address_list(), nullptr, nullptr, nullptr,
                         NetLogSource());

  // Install a callback that fails with an error unrelated to connecting, then
  // verify that this exact error is what Connect() reports.
  client.SetBeforeConnectCallback(base::BindRepeating(
      []() -> int { return net::ERR_NAME_NOT_RESOLVED; }));
  const int connect_rv = client.Connect(connect_callback.callback());
  EXPECT_THAT(connect_callback.GetResult(connect_rv),
              IsError(net::ERR_NAME_NOT_RESOLVED));

  // Best effort check that the socket wasn't accepted - may flakily pass on
  // regression, unfortunately.
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(accept_callback.have_result());
}
760
// SetKeepAlive() must fail on an unconnected socket and succeed, both for
// enabling and disabling, once the socket is connected.
TEST_F(TCPSocketTest, SetKeepAlive) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> server;
  IPEndPoint peer;
  EXPECT_THAT(socket_.Accept(&server, &peer, accept_callback.callback()),
              IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket client(local_address_list(), nullptr, nullptr, nullptr,
                         NetLogSource());

  // Before connecting, keep-alive cannot be configured.
  ASSERT_FALSE(client.IsConnected());
  EXPECT_FALSE(client.SetKeepAlive(/*enable=*/true, /*delay=*/14));

  const int connect_rv = client.Connect(connect_callback.callback());
  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_THAT(connect_callback.GetResult(connect_rv), IsOk());

  // Once connected, keep-alive can be switched both on and off.
  ASSERT_TRUE(client.IsConnected());
  EXPECT_TRUE(client.SetKeepAlive(/*enable=*/true, /*delay=*/22));
  EXPECT_TRUE(client.SetKeepAlive(/*enable=*/false, /*delay=*/3));
}
792
// Checks that NoDelay can only be configured on a connected client socket,
// and that once connected it can be toggled in both directions.
TEST_F(TCPSocketTest, SetNoDelay) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  // Post an accept on the listening socket; it should remain pending until a
  // client connects.
  TestCompletionCallback on_accepted;
  std::unique_ptr<TCPSocket> server_side_socket;
  IPEndPoint peer_address;
  const int accept_rv = socket_.Accept(&server_side_socket, &peer_address,
                                       on_accepted.callback());
  EXPECT_THAT(accept_rv, IsError(ERR_IO_PENDING));

  TestCompletionCallback on_connected;
  TCPClientSocket client(local_address_list(), nullptr, nullptr, nullptr,
                         NetLogSource());

  // An unconnected socket must refuse the option.
  ASSERT_FALSE(client.IsConnected());
  EXPECT_FALSE(client.SetNoDelay(/*no_delay=*/true));

  // Connect and wait for both the accept and the connect to complete.
  const int connect_rv = client.Connect(on_connected.callback());
  EXPECT_THAT(on_accepted.WaitForResult(), IsOk());
  EXPECT_THAT(on_connected.GetResult(connect_rv), IsOk());

  // Now the option can be switched on, then off again.
  ASSERT_TRUE(client.IsConnected());
  for (const bool no_delay : {true, false}) {
    EXPECT_TRUE(client.SetNoDelay(no_delay));
  }
}
821
822 // These tests require kernel support for tcp_info struct, and so they are
823 // enabled only on certain platforms.
824 #if defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
825 // If SocketPerformanceWatcher::ShouldNotifyUpdatedRTT always returns false,
826 // then the wtatcher should not receive any notifications.
TEST_F(TCPSocketTest,SPWNotInterested)827 TEST_F(TCPSocketTest, SPWNotInterested) {
828 TestSPWNotifications(false, 2u, 0u, 0u);
829 }
830
831 // One notification should be received when the socket connects. One
832 // additional notification should be received for each message read.
TEST_F(TCPSocketTest,SPWNoAdvance)833 TEST_F(TCPSocketTest, SPWNoAdvance) {
834 TestSPWNotifications(true, 2u, 0u, 3u);
835 }
836 #endif // defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
837
// On Android, where socket tagging is supported, verify that
// TCPSocket::ApplySocketTag works as expected.
840 #if BUILDFLAG(IS_ANDROID)
// Verifies that a tag applied before connecting is used for the connect
// traffic, and that the socket can then be retagged (with and without a UID)
// with subsequent writes counted against the new tag.
TEST_F(TCPSocketTest, Tag) {
  if (!CanGetTaggedBytes()) {
    // Report the test as skipped instead of silently passing.
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every later step needs an open socket, so stop here if this fails.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());

  // Verify TCP connect packets are tagged and counted properly.
  int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  // The writes below require a connected socket.
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  const int request1_len = static_cast<int>(strlen(kRequest1));
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  // GetResult() waits for the callback if the write completes asynchronously,
  // so the byte count is checked only after the write has finished.
  EXPECT_EQ(request1_len,
            write_callback1.GetResult(socket_.Write(
                write_buffer1.get(), request1_len, write_callback1.callback(),
                TRAFFIC_ANNOTATION_FOR_TESTS)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged back to the first value, whose tag leaves
  // the UID unset.
  old_traffic = GetTaggedBytes(tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  const int request2_len = static_cast<int>(strlen(kRequest2));
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(request2_len,
            write_callback2.GetResult(socket_.Write(
                write_buffer2.get(), request2_len, write_callback2.callback(),
                TRAFFIC_ANNOTATION_FOR_TESTS)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
899
// Verifies that an already-connected socket can be tagged (with the process
// UID) and then retagged (with the UID unset), with each subsequent write
// counted against the tag in effect at the time.
TEST_F(TCPSocketTest, TagAfterConnect) {
  if (!CanGetTaggedBytes()) {
    // Report the test as skipped instead of silently passing.
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every later step needs an open socket, so stop here if this fails.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());

  // Connect the socket before any tag is applied. The writes below require a
  // connected socket.
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());

  // Verify socket can be tagged with a new value and the current process's
  // UID.
  int32_t tag_val2 = 0x87654321;
  uint64_t old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  const int request1_len = static_cast<int>(strlen(kRequest1));
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  // GetResult() waits for the callback if the write completes asynchronously,
  // so the byte count is checked only after the write has finished.
  EXPECT_EQ(request1_len,
            write_callback1.GetResult(socket_.Write(
                write_buffer1.get(), request1_len, write_callback1.callback(),
                TRAFFIC_ANNOTATION_FOR_TESTS)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged with a new value, this time leaving the UID
  // unset.
  int32_t tag_val1 = 0x12345678;
  old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  const int request2_len = static_cast<int>(strlen(kRequest2));
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  EXPECT_EQ(request2_len,
            write_callback2.GetResult(socket_.Write(
                write_buffer2.get(), request2_len, write_callback2.callback(),
                TRAFFIC_ANNOTATION_FOR_TESTS)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
955
// Verifies that binding a socket associated with a bogus network handle fails
// (without ERR_NOT_IMPLEMENTED), while binding one associated with the default
// network succeeds when such a network exists.
TEST_F(TCPSocketTest, BindToNetwork) {
  NetworkChangeNotifierFactoryAndroid ncn_factory;
  NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // Arbitrary handle that is presumably not associated with any real network.
  const handles::NetworkHandle wrong_network_handle = 65536;
  // Try binding to this IP to trigger the underlying BindToNetwork call.
  const IPEndPoint ip(IPAddress::IPv4Localhost(), 0);
  TCPClientSocket wrong_socket(local_address_list(), nullptr, nullptr, nullptr,
                               NetLogSource(), wrong_network_handle);
  // Different Android versions might report different errors. Hence, just
  // check what shouldn't happen.
  const int rv = wrong_socket.Bind(ip);
  EXPECT_NE(OK, rv);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);

  // Binding a socket associated with an existing network should succeed. This
  // check is only possible when a default network is available.
  const handles::NetworkHandle network_handle =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (network_handle != handles::kInvalidNetworkHandle) {
    TCPClientSocket correct_socket(local_address_list(), nullptr, nullptr,
                                   nullptr, NetLogSource(), network_handle);
    EXPECT_EQ(OK, correct_socket.Bind(ip));
  }
}
984
985 #endif // BUILDFLAG(IS_ANDROID)
986
987 } // namespace
988 } // namespace net
989