xref: /aosp_15_r20/external/cronet/net/socket/tcp_socket_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/tcp_socket.h"
6 
7 #include <stddef.h>
8 #include <string.h>
9 
10 #include <memory>
11 #include <string>
12 #include <vector>
13 
14 #include "base/functional/bind.h"
15 #include "base/memory/ref_counted.h"
16 #include "base/test/bind.h"
17 #include "base/time/time.h"
18 #include "build/build_config.h"
19 #include "net/base/address_list.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/ip_endpoint.h"
22 #include "net/base/net_errors.h"
23 #include "net/base/sockaddr_storage.h"
24 #include "net/base/sys_addrinfo.h"
25 #include "net/base/test_completion_callback.h"
26 #include "net/log/net_log_source.h"
27 #include "net/socket/socket_descriptor.h"
28 #include "net/socket/socket_performance_watcher.h"
29 #include "net/socket/socket_test_util.h"
30 #include "net/socket/tcp_client_socket.h"
31 #include "net/test/embedded_test_server/embedded_test_server.h"
32 #include "net/test/gtest_util.h"
33 #include "net/test/test_with_task_environment.h"
34 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
35 #include "testing/gmock/include/gmock/gmock.h"
36 #include "testing/gtest/include/gtest/gtest.h"
37 #include "testing/platform_test.h"
38 
39 #if BUILDFLAG(IS_ANDROID)
40 #include "base/android/build_info.h"
41 #include "net/android/network_change_notifier_factory_android.h"
42 #include "net/base/network_change_notifier.h"
43 #endif  // BUILDFLAG(IS_ANDROID)
44 
45 // For getsockopt() call.
46 #if BUILDFLAG(IS_WIN)
47 #include <winsock2.h>
48 #else  // !BUILDFLAG(IS_WIN)
49 #include <sys/socket.h>
50 #endif  //  !BUILDFLAG(IS_WIN)
51 
52 using net::test::IsError;
53 using net::test::IsOk;
54 
55 namespace net {
56 
57 namespace {
58 
59 // IOBuffer with the ability to invoke a callback when destroyed. Useful for
60 // checking for leaks.
61 class IOBufferWithDestructionCallback : public IOBufferWithSize {
62  public:
IOBufferWithDestructionCallback(base::OnceClosure on_destroy_closure)63   explicit IOBufferWithDestructionCallback(base::OnceClosure on_destroy_closure)
64       : IOBufferWithSize(1024),
65         on_destroy_closure_(std::move(on_destroy_closure)) {
66     DCHECK(on_destroy_closure_);
67   }
68 
69  protected:
~IOBufferWithDestructionCallback()70   ~IOBufferWithDestructionCallback() override {
71     std::move(on_destroy_closure_).Run();
72   }
73 
74   base::OnceClosure on_destroy_closure_;
75 };
76 
77 class TestSocketPerformanceWatcher : public SocketPerformanceWatcher {
78  public:
TestSocketPerformanceWatcher(bool should_notify_updated_rtt)79   explicit TestSocketPerformanceWatcher(bool should_notify_updated_rtt)
80       : should_notify_updated_rtt_(should_notify_updated_rtt) {}
81 
82   TestSocketPerformanceWatcher(const TestSocketPerformanceWatcher&) = delete;
83   TestSocketPerformanceWatcher& operator=(const TestSocketPerformanceWatcher&) =
84       delete;
85 
86   ~TestSocketPerformanceWatcher() override = default;
87 
ShouldNotifyUpdatedRTT() const88   bool ShouldNotifyUpdatedRTT() const override {
89     return should_notify_updated_rtt_;
90   }
91 
OnUpdatedRTTAvailable(const base::TimeDelta & rtt)92   void OnUpdatedRTTAvailable(const base::TimeDelta& rtt) override {
93     rtt_notification_count_++;
94   }
95 
OnConnectionChanged()96   void OnConnectionChanged() override { connection_changed_count_++; }
97 
rtt_notification_count() const98   size_t rtt_notification_count() const { return rtt_notification_count_; }
99 
connection_changed_count() const100   size_t connection_changed_count() const { return connection_changed_count_; }
101 
102  private:
103   const bool should_notify_updated_rtt_;
104   size_t connection_changed_count_ = 0u;
105   size_t rtt_notification_count_ = 0u;
106 };
107 
// Backlog passed to Listen() by the listening sockets in these tests.
constexpr int kListenBacklog = 5;
109 
110 class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
111  protected:
TCPSocketTest()112   TCPSocketTest() : socket_(nullptr, nullptr, NetLogSource()) {}
113 
SetUpListenIPv4()114   void SetUpListenIPv4() {
115     ASSERT_THAT(socket_.Open(ADDRESS_FAMILY_IPV4), IsOk());
116     ASSERT_THAT(socket_.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
117                 IsOk());
118     ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
119     ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
120   }
121 
SetUpListenIPv6(bool * success)122   void SetUpListenIPv6(bool* success) {
123     *success = false;
124 
125     if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
126         socket_.Bind(IPEndPoint(IPAddress::IPv6Localhost(), 0)) != OK ||
127         socket_.Listen(kListenBacklog) != OK) {
128       LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
129           "disabled. Skipping the test";
130       return;
131     }
132     ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
133     *success = true;
134   }
135 
TestAcceptAsync()136   void TestAcceptAsync() {
137     TestCompletionCallback accept_callback;
138     std::unique_ptr<TCPSocket> accepted_socket;
139     IPEndPoint accepted_address;
140     ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
141                                accept_callback.callback()),
142                 IsError(ERR_IO_PENDING));
143 
144     TestCompletionCallback connect_callback;
145     TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
146                                       nullptr, NetLogSource());
147     int connect_result = connecting_socket.Connect(connect_callback.callback());
148     EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
149 
150     EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
151 
152     EXPECT_TRUE(accepted_socket.get());
153 
154     // Both sockets should be on the loopback network interface.
155     EXPECT_EQ(accepted_address.address(), local_address_.address());
156   }
157 
158 #if defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
159   // Tests that notifications to Socket Performance Watcher (SPW) are delivered
160   // correctly. |should_notify_updated_rtt| is true if the SPW is interested in
161   // receiving RTT notifications. |num_messages| is the number of messages that
162   // are written/read by the sockets. |expect_connection_changed_count| is the
163   // expected number of connection change notifications received by the SPW.
164   // |expect_rtt_notification_count| is the expected number of RTT
165   // notifications received by the SPW. This test works by writing
166   // |num_messages| to the socket. A different socket (with a SPW attached to
167   // it) reads the messages.
TestSPWNotifications(bool should_notify_updated_rtt,size_t num_messages,size_t expect_connection_changed_count,size_t expect_rtt_notification_count)168   void TestSPWNotifications(bool should_notify_updated_rtt,
169                             size_t num_messages,
170                             size_t expect_connection_changed_count,
171                             size_t expect_rtt_notification_count) {
172     ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
173 
174     TestCompletionCallback connect_callback;
175 
176     auto watcher = std::make_unique<TestSocketPerformanceWatcher>(
177         should_notify_updated_rtt);
178     TestSocketPerformanceWatcher* watcher_ptr = watcher.get();
179 
180     TCPSocket connecting_socket(std::move(watcher), nullptr, NetLogSource());
181 
182     int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
183     ASSERT_THAT(result, IsOk());
184     int connect_result =
185         connecting_socket.Connect(local_address_, connect_callback.callback());
186 
187     TestCompletionCallback accept_callback;
188     std::unique_ptr<TCPSocket> accepted_socket;
189     IPEndPoint accepted_address;
190     result = socket_.Accept(&accepted_socket, &accepted_address,
191                             accept_callback.callback());
192     ASSERT_THAT(accept_callback.GetResult(result), IsOk());
193 
194     ASSERT_TRUE(accepted_socket.get());
195 
196     // Both sockets should be on the loopback network interface.
197     EXPECT_EQ(accepted_address.address(), local_address_.address());
198 
199     ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
200 
201     for (size_t i = 0; i < num_messages; ++i) {
202       // Use a 1 byte message so that the watcher is notified at most once per
203       // message.
204       const std::string message("t");
205 
206       scoped_refptr<IOBufferWithSize> write_buffer =
207           base::MakeRefCounted<IOBufferWithSize>(message.size());
208       memmove(write_buffer->data(), message.data(), message.size());
209 
210       TestCompletionCallback write_callback;
211       int write_result = accepted_socket->Write(
212           write_buffer.get(), write_buffer->size(), write_callback.callback(),
213           TRAFFIC_ANNOTATION_FOR_TESTS);
214 
215       scoped_refptr<IOBufferWithSize> read_buffer =
216           base::MakeRefCounted<IOBufferWithSize>(message.size());
217       TestCompletionCallback read_callback;
218       int read_result = connecting_socket.Read(
219           read_buffer.get(), read_buffer->size(), read_callback.callback());
220 
221       ASSERT_EQ(1, write_callback.GetResult(write_result));
222       ASSERT_EQ(1, read_callback.GetResult(read_result));
223     }
224     EXPECT_EQ(expect_connection_changed_count,
225               watcher_ptr->connection_changed_count());
226     EXPECT_EQ(expect_rtt_notification_count,
227               watcher_ptr->rtt_notification_count());
228   }
229 #endif  // defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
230 
local_address_list() const231   AddressList local_address_list() const {
232     return AddressList(local_address_);
233   }
234 
235   TCPSocket socket_;
236   IPEndPoint local_address_;
237 };
238 
239 // Test listening and accepting with a socket bound to an IPv4 address.
TEST_F(TCPSocketTest,Accept)240 TEST_F(TCPSocketTest, Accept) {
241   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
242 
243   TestCompletionCallback connect_callback;
244   // TODO(yzshen): Switch to use TCPSocket when it supports client socket
245   // operations.
246   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
247                                     nullptr, NetLogSource());
248   int connect_result = connecting_socket.Connect(connect_callback.callback());
249 
250   TestCompletionCallback accept_callback;
251   std::unique_ptr<TCPSocket> accepted_socket;
252   IPEndPoint accepted_address;
253   int result = socket_.Accept(&accepted_socket, &accepted_address,
254                               accept_callback.callback());
255   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
256 
257   EXPECT_TRUE(accepted_socket.get());
258 
259   // Both sockets should be on the loopback network interface.
260   EXPECT_EQ(accepted_address.address(), local_address_.address());
261 
262   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
263 }
264 
265 // Test Accept() callback.
TEST_F(TCPSocketTest,AcceptAsync)266 TEST_F(TCPSocketTest, AcceptAsync) {
267   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
268   TestAcceptAsync();
269 }
270 
271 // Test AdoptConnectedSocket()
TEST_F(TCPSocketTest,AdoptConnectedSocket)272 TEST_F(TCPSocketTest, AdoptConnectedSocket) {
273   TCPSocket accepting_socket(nullptr, nullptr, NetLogSource());
274   ASSERT_THAT(accepting_socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
275   ASSERT_THAT(accepting_socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
276               IsOk());
277   ASSERT_THAT(accepting_socket.GetLocalAddress(&local_address_), IsOk());
278   ASSERT_THAT(accepting_socket.Listen(kListenBacklog), IsOk());
279 
280   TestCompletionCallback connect_callback;
281   // TODO(yzshen): Switch to use TCPSocket when it supports client socket
282   // operations.
283   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
284                                     nullptr, NetLogSource());
285   int connect_result = connecting_socket.Connect(connect_callback.callback());
286 
287   TestCompletionCallback accept_callback;
288   std::unique_ptr<TCPSocket> accepted_socket;
289   IPEndPoint accepted_address;
290   int result = accepting_socket.Accept(&accepted_socket, &accepted_address,
291                                        accept_callback.callback());
292   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
293 
294   SocketDescriptor accepted_descriptor =
295       accepted_socket->ReleaseSocketDescriptorForTesting();
296 
297   ASSERT_THAT(
298       socket_.AdoptConnectedSocket(accepted_descriptor, accepted_address),
299       IsOk());
300 
301   // socket_ should now have the local address.
302   IPEndPoint adopted_address;
303   ASSERT_THAT(socket_.GetLocalAddress(&adopted_address), IsOk());
304   EXPECT_EQ(local_address_.address(), adopted_address.address());
305 
306   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
307 }
308 
309 // Test Accept() for AdoptUnconnectedSocket.
TEST_F(TCPSocketTest,AcceptForAdoptedUnconnectedSocket)310 TEST_F(TCPSocketTest, AcceptForAdoptedUnconnectedSocket) {
311   SocketDescriptor existing_socket =
312       CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
313   ASSERT_THAT(socket_.AdoptUnconnectedSocket(existing_socket), IsOk());
314 
315   IPEndPoint address(IPAddress::IPv4Localhost(), 0);
316   SockaddrStorage storage;
317   ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len));
318   ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len));
319 
320   ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
321   ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
322 
323   TestAcceptAsync();
324 }
325 
326 // Accept two connections simultaneously.
TEST_F(TCPSocketTest,Accept2Connections)327 TEST_F(TCPSocketTest, Accept2Connections) {
328   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
329 
330   TestCompletionCallback accept_callback;
331   std::unique_ptr<TCPSocket> accepted_socket;
332   IPEndPoint accepted_address;
333 
334   ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
335                              accept_callback.callback()),
336               IsError(ERR_IO_PENDING));
337 
338   TestCompletionCallback connect_callback;
339   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
340                                     nullptr, NetLogSource());
341   int connect_result = connecting_socket.Connect(connect_callback.callback());
342 
343   TestCompletionCallback connect_callback2;
344   TCPClientSocket connecting_socket2(local_address_list(), nullptr, nullptr,
345                                      nullptr, NetLogSource());
346   int connect_result2 =
347       connecting_socket2.Connect(connect_callback2.callback());
348 
349   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
350 
351   TestCompletionCallback accept_callback2;
352   std::unique_ptr<TCPSocket> accepted_socket2;
353   IPEndPoint accepted_address2;
354 
355   int result = socket_.Accept(&accepted_socket2, &accepted_address2,
356                               accept_callback2.callback());
357   ASSERT_THAT(accept_callback2.GetResult(result), IsOk());
358 
359   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
360   EXPECT_THAT(connect_callback2.GetResult(connect_result2), IsOk());
361 
362   EXPECT_TRUE(accepted_socket.get());
363   EXPECT_TRUE(accepted_socket2.get());
364   EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
365 
366   EXPECT_EQ(accepted_address.address(), local_address_.address());
367   EXPECT_EQ(accepted_address2.address(), local_address_.address());
368 }
369 
370 // Test listening and accepting with a socket bound to an IPv6 address.
TEST_F(TCPSocketTest,AcceptIPv6)371 TEST_F(TCPSocketTest, AcceptIPv6) {
372   bool initialized = false;
373   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized));
374   if (!initialized)
375     return;
376 
377   TestCompletionCallback connect_callback;
378   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
379                                     nullptr, NetLogSource());
380   int connect_result = connecting_socket.Connect(connect_callback.callback());
381 
382   TestCompletionCallback accept_callback;
383   std::unique_ptr<TCPSocket> accepted_socket;
384   IPEndPoint accepted_address;
385   int result = socket_.Accept(&accepted_socket, &accepted_address,
386                               accept_callback.callback());
387   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
388 
389   EXPECT_TRUE(accepted_socket.get());
390 
391   // Both sockets should be on the loopback network interface.
392   EXPECT_EQ(accepted_address.address(), local_address_.address());
393 
394   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
395 }
396 
TEST_F(TCPSocketTest,ReadWrite)397 TEST_F(TCPSocketTest, ReadWrite) {
398   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
399 
400   TestCompletionCallback connect_callback;
401   TCPSocket connecting_socket(nullptr, nullptr, NetLogSource());
402   int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
403   ASSERT_THAT(result, IsOk());
404   int connect_result =
405       connecting_socket.Connect(local_address_, connect_callback.callback());
406 
407   TestCompletionCallback accept_callback;
408   std::unique_ptr<TCPSocket> accepted_socket;
409   IPEndPoint accepted_address;
410   result = socket_.Accept(&accepted_socket, &accepted_address,
411                           accept_callback.callback());
412   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
413 
414   ASSERT_TRUE(accepted_socket.get());
415 
416   // Both sockets should be on the loopback network interface.
417   EXPECT_EQ(accepted_address.address(), local_address_.address());
418 
419   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
420 
421   const std::string message("test message");
422   std::vector<char> buffer(message.size());
423 
424   size_t bytes_written = 0;
425   while (bytes_written < message.size()) {
426     scoped_refptr<IOBufferWithSize> write_buffer =
427         base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_written);
428     memmove(write_buffer->data(), message.data() + bytes_written,
429             message.size() - bytes_written);
430 
431     TestCompletionCallback write_callback;
432     int write_result = accepted_socket->Write(
433         write_buffer.get(), write_buffer->size(), write_callback.callback(),
434         TRAFFIC_ANNOTATION_FOR_TESTS);
435     write_result = write_callback.GetResult(write_result);
436     ASSERT_TRUE(write_result >= 0);
437     bytes_written += write_result;
438     ASSERT_TRUE(bytes_written <= message.size());
439   }
440 
441   size_t bytes_read = 0;
442   while (bytes_read < message.size()) {
443     scoped_refptr<IOBufferWithSize> read_buffer =
444         base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_read);
445     TestCompletionCallback read_callback;
446     int read_result = connecting_socket.Read(
447         read_buffer.get(), read_buffer->size(), read_callback.callback());
448     read_result = read_callback.GetResult(read_result);
449     ASSERT_TRUE(read_result >= 0);
450     ASSERT_TRUE(bytes_read + read_result <= message.size());
451     memmove(&buffer[bytes_read], read_buffer->data(), read_result);
452     bytes_read += read_result;
453   }
454 
455   std::string received_message(buffer.begin(), buffer.end());
456   ASSERT_EQ(message, received_message);
457 }
458 
459 // Destroy a TCPSocket while there's a pending read, and make sure the read
460 // IOBuffer that the socket was holding on to is destroyed.
461 // See https://crbug.com/804868.
TEST_F(TCPSocketTest,DestroyWithPendingRead)462 TEST_F(TCPSocketTest, DestroyWithPendingRead) {
463   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
464 
465   // Create a connected socket.
466 
467   TestCompletionCallback connect_callback;
468   std::unique_ptr<TCPSocket> connecting_socket =
469       std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
470   int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
471   ASSERT_THAT(result, IsOk());
472   int connect_result =
473       connecting_socket->Connect(local_address_, connect_callback.callback());
474 
475   TestCompletionCallback accept_callback;
476   std::unique_ptr<TCPSocket> accepted_socket;
477   IPEndPoint accepted_address;
478   result = socket_.Accept(&accepted_socket, &accepted_address,
479                           accept_callback.callback());
480   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
481   ASSERT_TRUE(accepted_socket.get());
482   ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
483 
484   // Try to read from the socket, but never write anything to the other end.
485   base::RunLoop run_loop;
486   scoped_refptr<IOBufferWithDestructionCallback> read_buffer(
487       base::MakeRefCounted<IOBufferWithDestructionCallback>(
488           run_loop.QuitClosure()));
489   TestCompletionCallback read_callback;
490   EXPECT_EQ(ERR_IO_PENDING,
491             connecting_socket->Read(read_buffer.get(), read_buffer->size(),
492                                     read_callback.callback()));
493 
494   // Release the handle to the read buffer and destroy the socket. Make sure the
495   // read buffer is destroyed.
496   read_buffer = nullptr;
497   connecting_socket.reset();
498   run_loop.Run();
499 }
500 
501 // Destroy a TCPSocket while there's a pending write, and make sure the write
502 // IOBuffer that the socket was holding on to is destroyed.
TEST_F(TCPSocketTest,DestroyWithPendingWrite)503 TEST_F(TCPSocketTest, DestroyWithPendingWrite) {
504   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
505 
506   // Create a connected socket.
507 
508   TestCompletionCallback connect_callback;
509   std::unique_ptr<TCPSocket> connecting_socket =
510       std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
511   int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
512   ASSERT_THAT(result, IsOk());
513   int connect_result =
514       connecting_socket->Connect(local_address_, connect_callback.callback());
515 
516   TestCompletionCallback accept_callback;
517   std::unique_ptr<TCPSocket> accepted_socket;
518   IPEndPoint accepted_address;
519   result = socket_.Accept(&accepted_socket, &accepted_address,
520                           accept_callback.callback());
521   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
522   ASSERT_TRUE(accepted_socket.get());
523   ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
524 
525   // Repeatedly write to the socket until an operation does not complete
526   // synchronously.
527   base::RunLoop run_loop;
528   scoped_refptr<IOBufferWithDestructionCallback> write_buffer(
529       base::MakeRefCounted<IOBufferWithDestructionCallback>(
530           run_loop.QuitClosure()));
531   memset(write_buffer->data(), '1', write_buffer->size());
532   TestCompletionCallback write_callback;
533   while (true) {
534     result = connecting_socket->Write(write_buffer.get(), write_buffer->size(),
535                                       write_callback.callback(),
536                                       TRAFFIC_ANNOTATION_FOR_TESTS);
537     if (result == ERR_IO_PENDING)
538       break;
539     ASSERT_LT(0, result);
540   }
541 
542   // Release the handle to the read buffer and destroy the socket. Make sure the
543   // write buffer is destroyed.
544   write_buffer = nullptr;
545   connecting_socket.reset();
546   run_loop.Run();
547 }
548 
549 // If a ReadIfReady is pending, it's legal to cancel it and start reading later.
TEST_F(TCPSocketTest,CancelPendingReadIfReady)550 TEST_F(TCPSocketTest, CancelPendingReadIfReady) {
551   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
552 
553   // Create a connected socket.
554   TestCompletionCallback connect_callback;
555   std::unique_ptr<TCPSocket> connecting_socket =
556       std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
557   int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
558   ASSERT_THAT(result, IsOk());
559   int connect_result =
560       connecting_socket->Connect(local_address_, connect_callback.callback());
561 
562   TestCompletionCallback accept_callback;
563   std::unique_ptr<TCPSocket> accepted_socket;
564   IPEndPoint accepted_address;
565   result = socket_.Accept(&accepted_socket, &accepted_address,
566                           accept_callback.callback());
567   ASSERT_THAT(accept_callback.GetResult(result), IsOk());
568   ASSERT_TRUE(accepted_socket.get());
569   ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
570 
571   // Try to read from the socket, but never write anything to the other end.
572   base::RunLoop run_loop;
573   scoped_refptr<IOBufferWithDestructionCallback> read_buffer(
574       base::MakeRefCounted<IOBufferWithDestructionCallback>(
575           run_loop.QuitClosure()));
576   TestCompletionCallback read_callback;
577   EXPECT_EQ(ERR_IO_PENDING, connecting_socket->ReadIfReady(
578                                 read_buffer.get(), read_buffer->size(),
579                                 read_callback.callback()));
580 
581   // Now cancel the pending ReadIfReady().
582   connecting_socket->CancelReadIfReady();
583 
584   // Send data to |connecting_socket|.
585   const char kMsg[] = "hello!";
586   scoped_refptr<StringIOBuffer> write_buffer =
587       base::MakeRefCounted<StringIOBuffer>(kMsg);
588 
589   TestCompletionCallback write_callback;
590   int write_result = accepted_socket->Write(write_buffer.get(), strlen(kMsg),
591                                             write_callback.callback(),
592                                             TRAFFIC_ANNOTATION_FOR_TESTS);
593   const int msg_size = strlen(kMsg);
594   ASSERT_EQ(msg_size, write_result);
595 
596   TestCompletionCallback read_callback2;
597   int read_result = connecting_socket->ReadIfReady(
598       read_buffer.get(), read_buffer->size(), read_callback2.callback());
599   if (read_result == ERR_IO_PENDING) {
600     ASSERT_EQ(OK, read_callback2.GetResult(read_result));
601     read_result = connecting_socket->ReadIfReady(
602         read_buffer.get(), read_buffer->size(), read_callback2.callback());
603   }
604 
605   ASSERT_EQ(msg_size, read_result);
606   ASSERT_EQ(0, memcmp(&kMsg, read_buffer->data(), msg_size));
607 }
608 
TEST_F(TCPSocketTest,IsConnected)609 TEST_F(TCPSocketTest, IsConnected) {
610   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
611 
612   TestCompletionCallback accept_callback;
613   std::unique_ptr<TCPSocket> accepted_socket;
614   IPEndPoint accepted_address;
615   EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
616                              accept_callback.callback()),
617               IsError(ERR_IO_PENDING));
618 
619   TestCompletionCallback connect_callback;
620   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
621                                     nullptr, NetLogSource());
622 
623   // Immediately after creation, the socket should not be connected.
624   EXPECT_FALSE(connecting_socket.IsConnected());
625   EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());
626 
627   int connect_result = connecting_socket.Connect(connect_callback.callback());
628   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
629   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
630 
631   // |connecting_socket| and |accepted_socket| should now both be reported as
632   // connected, and idle
633   EXPECT_TRUE(accepted_socket->IsConnected());
634   EXPECT_TRUE(accepted_socket->IsConnectedAndIdle());
635   EXPECT_TRUE(connecting_socket.IsConnected());
636   EXPECT_TRUE(connecting_socket.IsConnectedAndIdle());
637 
638   // Write one byte to the |accepted_socket|, then close it.
639   const char kSomeData[] = "!";
640   scoped_refptr<IOBuffer> some_data_buffer =
641       base::MakeRefCounted<StringIOBuffer>(kSomeData);
642   TestCompletionCallback write_callback;
643   EXPECT_THAT(write_callback.GetResult(accepted_socket->Write(
644                   some_data_buffer.get(), 1, write_callback.callback(),
645                   TRAFFIC_ANNOTATION_FOR_TESTS)),
646               1);
647   accepted_socket.reset();
648 
649   // Wait until |connecting_socket| is signalled as having data to read.
650   fd_set read_fds;
651   FD_ZERO(&read_fds);
652   SocketDescriptor connecting_fd =
653       connecting_socket.SocketDescriptorForTesting();
654   FD_SET(connecting_fd, &read_fds);
655   ASSERT_EQ(select(FD_SETSIZE, &read_fds, nullptr, nullptr, nullptr), 1);
656   ASSERT_TRUE(FD_ISSET(connecting_fd, &read_fds));
657 
658   // It should now be reported as connected, but not as idle.
659   EXPECT_TRUE(connecting_socket.IsConnected());
660   EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());
661 
662   // Read the message from |connecting_socket_|, then read the end-of-stream.
663   scoped_refptr<IOBufferWithSize> read_buffer =
664       base::MakeRefCounted<IOBufferWithSize>(2);
665   TestCompletionCallback read_callback;
666   EXPECT_THAT(
667       read_callback.GetResult(connecting_socket.Read(
668           read_buffer.get(), read_buffer->size(), read_callback.callback())),
669       1);
670   EXPECT_THAT(
671       read_callback.GetResult(connecting_socket.Read(
672           read_buffer.get(), read_buffer->size(), read_callback.callback())),
673       0);
674 
675   // |connecting_socket| has no more data to read, so should noe be reported
676   // as disconnected.
677   EXPECT_FALSE(connecting_socket.IsConnected());
678   EXPECT_FALSE(connecting_socket.IsConnectedAndIdle());
679 }
680 
681 // Tests that setting a socket option in the BeforeConnectCallback works. With
682 // real sockets, socket options often have to be set before the connect() call,
683 // and the BeforeConnectCallback is the only way to do that, with a
684 // TCPClientSocket.
TEST_F(TCPSocketTest,BeforeConnectCallback)685 TEST_F(TCPSocketTest, BeforeConnectCallback) {
686   // A receive buffer size that is between max and minimum buffer size limits,
687   // and weird enough to likely not be a default value.
688   const int kReceiveBufferSize = 32 * 1024 + 1117;
689   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
690 
691   TestCompletionCallback accept_callback;
692   std::unique_ptr<TCPSocket> accepted_socket;
693   IPEndPoint accepted_address;
694   EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
695                              accept_callback.callback()),
696               IsError(ERR_IO_PENDING));
697 
698   TestCompletionCallback connect_callback;
699   TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
700                                     nullptr, NetLogSource());
701 
702   connecting_socket.SetBeforeConnectCallback(base::BindLambdaForTesting([&] {
703     EXPECT_FALSE(connecting_socket.IsConnected());
704     int result = connecting_socket.SetReceiveBufferSize(kReceiveBufferSize);
705     EXPECT_THAT(result, IsOk());
706     return result;
707   }));
708   int connect_result = connecting_socket.Connect(connect_callback.callback());
709 
710   EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
711   EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
712 
713   int actual_size = 0;
714   socklen_t actual_size_len = sizeof(actual_size);
715   int os_result = getsockopt(
716       connecting_socket.SocketDescriptorForTesting(), SOL_SOCKET, SO_RCVBUF,
717       reinterpret_cast<char*>(&actual_size), &actual_size_len);
718   ASSERT_EQ(0, os_result);
719 // Linux platforms generally allocate twice as much buffer size is requested to
720 // account for internal kernel data structures.
721 #if BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID)
722   EXPECT_EQ(2 * kReceiveBufferSize, actual_size);
723 // Unfortunately, Apple platform behavior doesn't seem to be documented, and
724 // doesn't match behavior on any other platforms.
725 // Fuchsia doesn't currently implement SO_RCVBUF.
726 #elif !BUILDFLAG(IS_APPLE) && !BUILDFLAG(IS_FUCHSIA)
727   EXPECT_EQ(kReceiveBufferSize, actual_size);
728 #endif
729 }
730 
TEST_F(TCPSocketTest, BeforeConnectCallbackFails) {
  // The listening server is optional for this test; it exists so that we can
  // verify afterwards that nothing ever connected to it.
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> accepted_socket;
  IPEndPoint accepted_address;
  const int accept_rv = socket_.Accept(&accepted_socket, &accepted_address,
                                       accept_callback.callback());
  EXPECT_THAT(accept_rv, IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket connecting_socket(local_address_list(), nullptr, nullptr,
                                    nullptr, NetLogSource());

  // Install a pre-connect hook that fails with an error that makes no sense at
  // this stage; Connect() must surface exactly that error.
  connecting_socket.SetBeforeConnectCallback(base::BindRepeating(
      [] { return static_cast<int>(net::ERR_NAME_NOT_RESOLVED); }));
  const int rv = connecting_socket.Connect(connect_callback.callback());
  EXPECT_THAT(connect_callback.GetResult(rv),
              IsError(net::ERR_NAME_NOT_RESOLVED));

  // Drain pending work, then confirm the listener accepted nothing. This is
  // only best-effort: a regression may still slip through on some runs.
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(accept_callback.have_result());
}
760 
TEST_F(TCPSocketTest, SetKeepAlive) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  // Queue an accept on the listening socket; it stays pending until the
  // client below connects.
  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> accepted_socket;
  IPEndPoint accepted_address;
  const int accept_rv = socket_.Accept(&accepted_socket, &accepted_address,
                                       accept_callback.callback());
  EXPECT_THAT(accept_rv, IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket client_socket(local_address_list(), nullptr, nullptr,
                                nullptr, NetLogSource());

  // SetKeepAlive() must be rejected while the socket is not yet connected.
  ASSERT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(client_socket.SetKeepAlive(true /* enable */, 14 /* delay */));

  const int connect_rv = client_socket.Connect(connect_callback.callback());
  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_THAT(connect_callback.GetResult(connect_rv), IsOk());

  // After connecting, KeepAlive can be turned both on and off.
  ASSERT_TRUE(client_socket.IsConnected());
  EXPECT_TRUE(client_socket.SetKeepAlive(true /* enable */, 22 /* delay */));
  EXPECT_TRUE(client_socket.SetKeepAlive(false /* enable */, 3 /* delay */));
}
792 
TEST_F(TCPSocketTest, SetNoDelay) {
  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());

  // Queue an accept on the listening socket; it stays pending until the
  // client below connects.
  TestCompletionCallback accept_callback;
  std::unique_ptr<TCPSocket> accepted_socket;
  IPEndPoint accepted_address;
  const int accept_rv = socket_.Accept(&accepted_socket, &accepted_address,
                                       accept_callback.callback());
  EXPECT_THAT(accept_rv, IsError(ERR_IO_PENDING));

  TestCompletionCallback connect_callback;
  TCPClientSocket client_socket(local_address_list(), nullptr, nullptr,
                                nullptr, NetLogSource());

  // SetNoDelay() must be rejected while the socket is not yet connected.
  ASSERT_FALSE(client_socket.IsConnected());
  EXPECT_FALSE(client_socket.SetNoDelay(true /* no_delay */));

  const int connect_rv = client_socket.Connect(connect_callback.callback());
  EXPECT_THAT(accept_callback.WaitForResult(), IsOk());
  EXPECT_THAT(connect_callback.GetResult(connect_rv), IsOk());

  // After connecting, NoDelay can be turned both on and off.
  ASSERT_TRUE(client_socket.IsConnected());
  EXPECT_TRUE(client_socket.SetNoDelay(true /* no_delay */));
  EXPECT_TRUE(client_socket.SetNoDelay(false /* no_delay */));
}
821 
822 // These tests require kernel support for tcp_info struct, and so they are
823 // enabled only on certain platforms.
824 #if defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
825 // If SocketPerformanceWatcher::ShouldNotifyUpdatedRTT always returns false,
826 // then the wtatcher should not receive any notifications.
TEST_F(TCPSocketTest, SPWNotInterested) {
  // NOTE(review): assumes the first argument of TestSPWNotifications is the
  // value the watcher's ShouldNotifyUpdatedRTT() returns -- confirm.
  constexpr bool kShouldNotifyUpdatedRtt = false;
  TestSPWNotifications(kShouldNotifyUpdatedRtt, 2u, 0u, 0u);
}
830 
831 // One notification should be received when the socket connects. One
832 // additional notification should be received for each message read.
TEST_F(TCPSocketTest, SPWNoAdvance) {
  // Same harness as SPWNotInterested, but with the watcher opted in.
  // NOTE(review): assumes the first argument of TestSPWNotifications is the
  // value the watcher's ShouldNotifyUpdatedRTT() returns -- confirm.
  constexpr bool kShouldNotifyUpdatedRtt = true;
  TestSPWNotifications(kShouldNotifyUpdatedRtt, 2u, 0u, 3u);
}
836 #endif  // defined(TCP_INFO) || BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS)
837 
// On Android, where socket tagging is supported, verify that
// TCPSocket::ApplySocketTag works as expected.
840 #if BUILDFLAG(IS_ANDROID)
TEST_F(TCPSocketTest, Tag) {
  // Report the test as skipped, rather than silently passing, when tagged
  // byte counts cannot be read on this device.
  if (!CanGetTaggedBytes()) {
    GTEST_SKIP() << "GetTaggedBytes unsupported.";
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every later step uses |socket_|, so stop here if it can't be opened.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());

  // Verify TCP connect packets are tagged and counted properly.
  const int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  // The writes below require a connected socket.
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify socket can be retagged with a new value and the current process's
  // UID. The literal does not fit in int32_t; the cast makes the conversion
  // to a negative tag value explicit.
  const int32_t tag_val2 = static_cast<int32_t>(0x87654321);
  old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  int write_result1 = socket_.Write(write_buffer1.get(), strlen(kRequest1),
                                    write_callback1.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  // GetResult() also waits for completion if the write goes asynchronous.
  EXPECT_EQ(write_callback1.GetResult(write_result1),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged back to |tag1|, which leaves the UID unset.
  old_traffic = GetTaggedBytes(tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  int write_result2 = socket_.Write(write_buffer2.get(), strlen(kRequest2),
                                    write_callback2.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(write_callback2.GetResult(write_result2),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
899 
TEST_F(TCPSocketTest, TagAfterConnect) {
  // Report the test as skipped, rather than silently passing, when tagged
  // byte counts cannot be read on this device.
  if (!CanGetTaggedBytes()) {
    GTEST_SKIP() << "GetTaggedBytes unsupported.";
  }

  // Start test server.
  EmbeddedTestServer test_server;
  test_server.AddDefaultHandlers(base::FilePath());
  ASSERT_TRUE(test_server.Start());

  AddressList addr_list;
  ASSERT_TRUE(test_server.GetAddressList(&addr_list));
  // Every later step uses |socket_|, so stop here if it can't be opened.
  ASSERT_THAT(socket_.Open(addr_list[0].GetFamily()), IsOk());

  // Connect socket. The tagging checks below require a connected socket.
  TestCompletionCallback connect_callback;
  int connect_result =
      socket_.Connect(addr_list[0], connect_callback.callback());
  ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());

  // Verify socket can be tagged with a new value and the current process's
  // UID. The literal does not fit in int32_t; the cast makes the conversion
  // to a negative tag value explicit.
  const int32_t tag_val2 = static_cast<int32_t>(0x87654321);
  uint64_t old_traffic = GetTaggedBytes(tag_val2);
  SocketTag tag2(getuid(), tag_val2);
  socket_.ApplySocketTag(tag2);
  const char kRequest1[] = "GET / HTTP/1.0";
  scoped_refptr<IOBuffer> write_buffer1 =
      base::MakeRefCounted<StringIOBuffer>(kRequest1);
  TestCompletionCallback write_callback1;
  int write_result1 = socket_.Write(write_buffer1.get(), strlen(kRequest1),
                                    write_callback1.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  // GetResult() also waits for completion if the write goes asynchronous.
  EXPECT_EQ(write_callback1.GetResult(write_result1),
            static_cast<int>(strlen(kRequest1)));
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify socket can be retagged with a different value and no UID
  // (SocketTag::UNSET_UID).
  const int32_t tag_val1 = 0x12345678;
  old_traffic = GetTaggedBytes(tag_val1);
  SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  socket_.ApplySocketTag(tag1);
  const char kRequest2[] = "\n\n";
  scoped_refptr<IOBuffer> write_buffer2 =
      base::MakeRefCounted<StringIOBuffer>(kRequest2);
  TestCompletionCallback write_callback2;
  int write_result2 = socket_.Write(write_buffer2.get(), strlen(kRequest2),
                                    write_callback2.callback(),
                                    TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_EQ(write_callback2.GetResult(write_result2),
            static_cast<int>(strlen(kRequest2)));
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  socket_.Close();
}
955 
TEST_F(TCPSocketTest, BindToNetwork) {
  NetworkChangeNotifierFactoryAndroid ncn_factory;
  NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A handle the test expects not to match any real network, so binding to
  // it must fail.
  const handles::NetworkHandle wrong_network_handle = 65536;
  // Try binding to this IP to trigger the underlying BindToNetwork call.
  const IPEndPoint ip(IPAddress::IPv4Localhost(), 0);
  TCPClientSocket wrong_socket(local_address_list(), nullptr, nullptr, nullptr,
                               NetLogSource(), wrong_network_handle);
  // Different Android versions might report different errors. Hence, just check
  // what shouldn't happen.
  int rv = wrong_socket.Bind(ip);
  EXPECT_NE(OK, rv);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);

  // Binding to the default network, when there is one, should succeed.
  const handles::NetworkHandle network_handle =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (network_handle != handles::kInvalidNetworkHandle) {
    TCPClientSocket correct_socket(local_address_list(), nullptr, nullptr,
                                   nullptr, NetLogSource(), network_handle);
    EXPECT_EQ(OK, correct_socket.Bind(ip));
  }
}
984 
985 #endif  // BUILDFLAG(IS_ANDROID)
986 
987 }  // namespace
988 }  // namespace net
989