1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/udp_socket.h"
6
7 #include <algorithm>
8
9 #include "base/containers/circular_deque.h"
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/memory/raw_ptr.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/run_loop.h"
15 #include "base/scoped_clear_last_error.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/test/scoped_feature_list.h"
19 #include "base/threading/thread.h"
20 #include "base/time/time.h"
21 #include "build/build_config.h"
22 #include "build/chromeos_buildflags.h"
23 #include "net/base/features.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/ip_address.h"
26 #include "net/base/ip_endpoint.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_interfaces.h"
29 #include "net/base/test_completion_callback.h"
30 #include "net/log/net_log_event_type.h"
31 #include "net/log/net_log_source.h"
32 #include "net/log/test_net_log.h"
33 #include "net/log/test_net_log_util.h"
34 #include "net/socket/socket_test_util.h"
35 #include "net/socket/udp_client_socket.h"
36 #include "net/socket/udp_server_socket.h"
37 #include "net/socket/udp_socket_global_limits.h"
38 #include "net/test/gtest_util.h"
39 #include "net/test/test_with_task_environment.h"
40 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
41 #include "testing/gmock/include/gmock/gmock.h"
42 #include "testing/gtest/include/gtest/gtest.h"
43 #include "testing/platform_test.h"
44
45 #if !BUILDFLAG(IS_WIN)
46 #include <netinet/in.h>
47 #include <sys/socket.h>
48 #else
49 #include <winsock2.h>
50 #endif
51
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #include "net/android/network_change_notifier_factory_android.h"
55 #include "net/base/network_change_notifier.h"
56 #endif
57
58 #if BUILDFLAG(IS_IOS)
59 #include <TargetConditionals.h>
60 #endif
61
62 #if BUILDFLAG(IS_MAC)
63 #include "base/mac/mac_util.h"
64 #endif // BUILDFLAG(IS_MAC)
65
66 using net::test::IsError;
67 using net::test::IsOk;
68 using testing::DoAll;
69 using testing::Not;
70
71 namespace net {
72
73 namespace {
74
75 // Creates an address from ip address and port and writes it to |*address|.
CreateUDPAddress(const std::string & ip_str,uint16_t port,IPEndPoint * address)76 bool CreateUDPAddress(const std::string& ip_str,
77 uint16_t port,
78 IPEndPoint* address) {
79 IPAddress ip_address;
80 if (!ip_address.AssignFromIPLiteral(ip_str))
81 return false;
82
83 *address = IPEndPoint(ip_address, port);
84 return true;
85 }
86
87 class UDPSocketTest : public PlatformTest, public WithTaskEnvironment {
88 public:
UDPSocketTest()89 UDPSocketTest() : buffer_(base::MakeRefCounted<IOBufferWithSize>(kMaxRead)) {}
90
91 // Blocks until data is read from the socket.
RecvFromSocket(UDPServerSocket * socket)92 std::string RecvFromSocket(UDPServerSocket* socket) {
93 TestCompletionCallback callback;
94
95 int rv = socket->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
96 callback.callback());
97 rv = callback.GetResult(rv);
98 if (rv < 0)
99 return std::string();
100 return std::string(buffer_->data(), rv);
101 }
102
103 // Sends UDP packet.
104 // If |address| is specified, then it is used for the destination
105 // to send to. Otherwise, will send to the last socket this server
106 // received from.
SendToSocket(UDPServerSocket * socket,const std::string & msg)107 int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
108 return SendToSocket(socket, msg, recv_from_address_);
109 }
110
SendToSocket(UDPServerSocket * socket,std::string msg,const IPEndPoint & address)111 int SendToSocket(UDPServerSocket* socket,
112 std::string msg,
113 const IPEndPoint& address) {
114 scoped_refptr<StringIOBuffer> io_buffer =
115 base::MakeRefCounted<StringIOBuffer>(msg);
116 TestCompletionCallback callback;
117 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
118 callback.callback());
119 return callback.GetResult(rv);
120 }
121
ReadSocket(UDPClientSocket * socket)122 std::string ReadSocket(UDPClientSocket* socket) {
123 return ReadSocket(socket, DSCP_DEFAULT, ECN_DEFAULT);
124 }
125
ReadSocket(UDPClientSocket * socket,DiffServCodePoint dscp,EcnCodePoint ecn)126 std::string ReadSocket(UDPClientSocket* socket,
127 DiffServCodePoint dscp,
128 EcnCodePoint ecn) {
129 TestCompletionCallback callback;
130
131 int rv = socket->Read(buffer_.get(), kMaxRead, callback.callback());
132 rv = callback.GetResult(rv);
133 if (rv < 0)
134 return std::string();
135 #if BUILDFLAG(IS_WIN)
136 // The DSCP value is not populated on Windows, in order to avoid incurring
137 // an extra system call.
138 EXPECT_EQ(socket->GetLastTos().dscp, DSCP_DEFAULT);
139 #else
140 EXPECT_EQ(socket->GetLastTos().dscp, dscp);
141 #endif
142 EXPECT_EQ(socket->GetLastTos().ecn, ecn);
143 return std::string(buffer_->data(), rv);
144 }
145
146 // Writes specified message to the socket.
WriteSocket(UDPClientSocket * socket,const std::string & msg)147 int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
148 scoped_refptr<StringIOBuffer> io_buffer =
149 base::MakeRefCounted<StringIOBuffer>(msg);
150 TestCompletionCallback callback;
151 int rv = socket->Write(io_buffer.get(), io_buffer->size(),
152 callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
153 return callback.GetResult(rv);
154 }
155
WriteSocketIgnoreResult(UDPClientSocket * socket,const std::string & msg)156 void WriteSocketIgnoreResult(UDPClientSocket* socket,
157 const std::string& msg) {
158 WriteSocket(socket, msg);
159 }
160
161 // And again for a bare socket
SendToSocket(UDPSocket * socket,std::string msg,const IPEndPoint & address)162 int SendToSocket(UDPSocket* socket,
163 std::string msg,
164 const IPEndPoint& address) {
165 auto io_buffer = base::MakeRefCounted<StringIOBuffer>(msg);
166 TestCompletionCallback callback;
167 int rv = socket->SendTo(io_buffer.get(), io_buffer->size(), address,
168 callback.callback());
169 return callback.GetResult(rv);
170 }
171
172 // Run unit test for a connection test.
173 // |use_nonblocking_io| is used to switch between overlapped and non-blocking
174 // IO on Windows. It has no effect in other ports.
175 void ConnectTest(bool use_nonblocking_io, bool use_async);
176
177 protected:
178 static const int kMaxRead = 1024;
179 scoped_refptr<IOBufferWithSize> buffer_;
180 IPEndPoint recv_from_address_;
181 };
182
183 const int UDPSocketTest::kMaxRead;
184
ReadCompleteCallback(int * result_out,base::OnceClosure callback,int result)185 void ReadCompleteCallback(int* result_out,
186 base::OnceClosure callback,
187 int result) {
188 *result_out = result;
189 std::move(callback).Run();
190 }
191
ConnectTest(bool use_nonblocking_io,bool use_async)192 void UDPSocketTest::ConnectTest(bool use_nonblocking_io, bool use_async) {
193 std::string simple_message("hello world!");
194 RecordingNetLogObserver net_log_observer;
195 // Setup the server to listen.
196 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
197 auto server =
198 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
199 if (use_nonblocking_io)
200 server->UseNonBlockingIO();
201 server->AllowAddressReuse();
202 ASSERT_THAT(server->Listen(server_address), IsOk());
203 // Get bound port.
204 ASSERT_THAT(server->GetLocalAddress(&server_address), IsOk());
205
206 // Setup the client.
207 auto client = std::make_unique<UDPClientSocket>(
208 DatagramSocket::DEFAULT_BIND, NetLog::Get(), NetLogSource());
209 if (use_nonblocking_io)
210 client->UseNonBlockingIO();
211
212 if (!use_async) {
213 EXPECT_THAT(client->Connect(server_address), IsOk());
214 } else {
215 TestCompletionCallback callback;
216 int rv = client->ConnectAsync(server_address, callback.callback());
217 if (rv != OK) {
218 ASSERT_EQ(rv, ERR_IO_PENDING);
219 rv = callback.WaitForResult();
220 EXPECT_EQ(rv, OK);
221 } else {
222 EXPECT_EQ(rv, OK);
223 }
224 }
225 // Client sends to the server.
226 EXPECT_EQ(simple_message.length(),
227 static_cast<size_t>(WriteSocket(client.get(), simple_message)));
228
229 // Server waits for message.
230 std::string str = RecvFromSocket(server.get());
231 EXPECT_EQ(simple_message, str);
232
233 // Server echoes reply.
234 EXPECT_EQ(simple_message.length(),
235 static_cast<size_t>(SendToSocket(server.get(), simple_message)));
236
237 // Client waits for response.
238 str = ReadSocket(client.get());
239 EXPECT_EQ(simple_message, str);
240
241 // Test asynchronous read. Server waits for message.
242 base::RunLoop run_loop;
243 int read_result = 0;
244 int rv = server->RecvFrom(buffer_.get(), kMaxRead, &recv_from_address_,
245 base::BindOnce(&ReadCompleteCallback, &read_result,
246 run_loop.QuitClosure()));
247 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
248
249 // Client sends to the server.
250 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
251 FROM_HERE,
252 base::BindOnce(&UDPSocketTest::WriteSocketIgnoreResult,
253 base::Unretained(this), client.get(), simple_message));
254 run_loop.Run();
255 EXPECT_EQ(simple_message.length(), static_cast<size_t>(read_result));
256 EXPECT_EQ(simple_message, std::string(buffer_->data(), read_result));
257
258 NetLogSource server_net_log_source = server->NetLog().source();
259 NetLogSource client_net_log_source = client->NetLog().source();
260
261 // Delete sockets so they log their final events.
262 server.reset();
263 client.reset();
264
265 // Check the server's log.
266 auto server_entries =
267 net_log_observer.GetEntriesForSource(server_net_log_source);
268 ASSERT_EQ(6u, server_entries.size());
269 EXPECT_TRUE(
270 LogContainsBeginEvent(server_entries, 0, NetLogEventType::SOCKET_ALIVE));
271 EXPECT_TRUE(LogContainsEvent(server_entries, 1,
272 NetLogEventType::UDP_LOCAL_ADDRESS,
273 NetLogEventPhase::NONE));
274 EXPECT_TRUE(LogContainsEvent(server_entries, 2,
275 NetLogEventType::UDP_BYTES_RECEIVED,
276 NetLogEventPhase::NONE));
277 EXPECT_TRUE(LogContainsEvent(server_entries, 3,
278 NetLogEventType::UDP_BYTES_SENT,
279 NetLogEventPhase::NONE));
280 EXPECT_TRUE(LogContainsEvent(server_entries, 4,
281 NetLogEventType::UDP_BYTES_RECEIVED,
282 NetLogEventPhase::NONE));
283 EXPECT_TRUE(
284 LogContainsEndEvent(server_entries, 5, NetLogEventType::SOCKET_ALIVE));
285
286 // Check the client's log.
287 auto client_entries =
288 net_log_observer.GetEntriesForSource(client_net_log_source);
289 EXPECT_EQ(7u, client_entries.size());
290 EXPECT_TRUE(
291 LogContainsBeginEvent(client_entries, 0, NetLogEventType::SOCKET_ALIVE));
292 EXPECT_TRUE(
293 LogContainsBeginEvent(client_entries, 1, NetLogEventType::UDP_CONNECT));
294 EXPECT_TRUE(
295 LogContainsEndEvent(client_entries, 2, NetLogEventType::UDP_CONNECT));
296 EXPECT_TRUE(LogContainsEvent(client_entries, 3,
297 NetLogEventType::UDP_BYTES_SENT,
298 NetLogEventPhase::NONE));
299 EXPECT_TRUE(LogContainsEvent(client_entries, 4,
300 NetLogEventType::UDP_BYTES_RECEIVED,
301 NetLogEventPhase::NONE));
302 EXPECT_TRUE(LogContainsEvent(client_entries, 5,
303 NetLogEventType::UDP_BYTES_SENT,
304 NetLogEventPhase::NONE));
305 EXPECT_TRUE(
306 LogContainsEndEvent(client_entries, 6, NetLogEventType::SOCKET_ALIVE));
307 }
308
TEST_F(UDPSocketTest, Connect) {
  // |use_nonblocking_io| only matters on Windows, so it stays false here.
  // Cover the synchronous connect path first, then the asynchronous one.
  for (bool use_async : {false, true})
    ConnectTest(/*use_nonblocking_io=*/false, use_async);
}
315
316 #if BUILDFLAG(IS_WIN)
TEST_F(UDPSocketTest,ConnectNonBlocking)317 TEST_F(UDPSocketTest, ConnectNonBlocking) {
318 ConnectTest(true, false);
319 ConnectTest(true, true);
320 }
321 #endif
322
// Receiving a datagram into a buffer that is too small fails with
// ERR_MSG_TOO_BIG, and the next datagram can still be received in full.
TEST_F(UDPSocketTest, PartialRecv) {
  UDPServerSocket server_socket(nullptr, NetLogSource());
  ASSERT_THAT(server_socket.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server_socket.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
                                NetLogSource());
  ASSERT_THAT(client_socket.Connect(server_address), IsOk());

  std::string test_packet("hello world!");
  ASSERT_EQ(static_cast<int>(test_packet.size()),
            WriteSocket(&client_socket, test_packet));

  TestCompletionCallback recv_callback;

  // Receive into a buffer only 2 bytes long. The datagram does not fit, so
  // RecvFrom() is expected to fail with ERR_MSG_TOO_BIG rather than return a
  // truncated payload.
  const int kPartialReadSize = 2;
  auto buffer = base::MakeRefCounted<IOBufferWithSize>(kPartialReadSize);
  int rv =
      server_socket.RecvFrom(buffer.get(), kPartialReadSize,
                             &recv_from_address_, recv_callback.callback());
  rv = recv_callback.GetResult(rv);

  EXPECT_THAT(rv, IsError(ERR_MSG_TOO_BIG));

  // The oversized datagram must have been consumed by the failed read: send
  // a different message and check that a full-size read returns it.
  std::string second_packet("Second packet");
  ASSERT_EQ(static_cast<int>(second_packet.size()),
            WriteSocket(&client_socket, second_packet));

  // Read whole packet now.
  std::string received = RecvFromSocket(&server_socket);
  EXPECT_EQ(second_packet, received);
}
360
361 #if BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_ANDROID)
362 // - MacOS: requires root permissions on OSX 10.7+.
363 // - Android: devices attached to testbots don't have default network, so
364 // broadcasting to 255.255.255.255 returns error -109 (Address not reachable).
365 // crbug.com/139144.
366 #define MAYBE_LocalBroadcast DISABLED_LocalBroadcast
367 #else
368 #define MAYBE_LocalBroadcast LocalBroadcast
369 #endif
TEST_F(UDPSocketTest,MAYBE_LocalBroadcast)370 TEST_F(UDPSocketTest, MAYBE_LocalBroadcast) {
371 std::string first_message("first message"), second_message("second message");
372
373 IPEndPoint listen_address;
374 ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &listen_address));
375
376 auto server1 =
377 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
378 auto server2 =
379 std::make_unique<UDPServerSocket>(NetLog::Get(), NetLogSource());
380 server1->AllowAddressReuse();
381 server1->AllowBroadcast();
382 server2->AllowAddressReuse();
383 server2->AllowBroadcast();
384
385 EXPECT_THAT(server1->Listen(listen_address), IsOk());
386 // Get bound port.
387 EXPECT_THAT(server1->GetLocalAddress(&listen_address), IsOk());
388 EXPECT_THAT(server2->Listen(listen_address), IsOk());
389
390 IPEndPoint broadcast_address;
391 ASSERT_TRUE(CreateUDPAddress("127.255.255.255", listen_address.port(),
392 &broadcast_address));
393 ASSERT_EQ(static_cast<int>(first_message.size()),
394 SendToSocket(server1.get(), first_message, broadcast_address));
395 std::string str = RecvFromSocket(server1.get());
396 ASSERT_EQ(first_message, str);
397 str = RecvFromSocket(server2.get());
398 ASSERT_EQ(first_message, str);
399
400 ASSERT_EQ(static_cast<int>(second_message.size()),
401 SendToSocket(server2.get(), second_message, broadcast_address));
402 str = RecvFromSocket(server1.get());
403 ASSERT_EQ(second_message, str);
404 str = RecvFromSocket(server2.get());
405 ASSERT_EQ(second_message, str);
406 }
407
408 // ConnectRandomBind verifies RANDOM_BIND is handled correctly. It connects
409 // 1000 sockets and then verifies that the allocated port numbers satisfy the
410 // following 2 conditions:
411 // 1. Range from min port value to max is greater than 10000.
412 // 2. There is at least one port in the 5 buckets in the [min, max] range.
413 //
414 // These conditions are not enough to verify that the port numbers are truly
415 // random, but they are enough to protect from most common non-random port
416 // allocation strategies (e.g. counter, pool of available ports, etc.) False
417 // positive result is theoretically possible, but its probability is negligible.
TEST_F(UDPSocketTest,ConnectRandomBind)418 TEST_F(UDPSocketTest, ConnectRandomBind) {
419 const int kIterations = 1000;
420
421 std::vector<int> used_ports;
422 for (int i = 0; i < kIterations; ++i) {
423 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
424 NetLogSource());
425 EXPECT_THAT(socket.Connect(IPEndPoint(IPAddress::IPv4Localhost(), 53)),
426 IsOk());
427
428 IPEndPoint client_address;
429 EXPECT_THAT(socket.GetLocalAddress(&client_address), IsOk());
430 used_ports.push_back(client_address.port());
431 }
432
433 int min_port = *std::min_element(used_ports.begin(), used_ports.end());
434 int max_port = *std::max_element(used_ports.begin(), used_ports.end());
435 int range = max_port - min_port + 1;
436
437 // Verify that the range of ports used by the random port allocator is wider
438 // than 10k. Assuming that socket implementation limits port range to 16k
439 // ports (default on Fuchsia) probability of false negative is below
440 // 10^-200.
441 static int kMinRange = 10000;
442 EXPECT_GT(range, kMinRange);
443
444 static int kBuckets = 5;
445 std::vector<int> bucket_sizes(kBuckets, 0);
446 for (int port : used_ports) {
447 bucket_sizes[(port - min_port) * kBuckets / range] += 1;
448 }
449
450 // Verify that there is at least one value in each bucket. Probability of
451 // false negative is below (kBuckets * (1 - 1 / kBuckets) ^ kIterations),
452 // which is less than 10^-96.
453 for (int size : bucket_sizes) {
454 EXPECT_GT(size, 0);
455 }
456 }
457
// Connecting a socket opened for IPv4 to an IPv6 endpoint must fail and leave
// the socket unconnected.
TEST_F(UDPSocketTest, ConnectFail) {
  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());

  // Address family mismatch: the socket is IPv4, the target is IPv6.
  const IPEndPoint ipv6_target(IPAddress::IPv6Localhost(), 53);
  EXPECT_THAT(socket.Connect(ipv6_target), Not(IsOk()));

  // After the failed Connect(), the socket must not report being connected.
  EXPECT_FALSE(socket.is_connected());
}
471
472 // Similar to ConnectFail but UDPSocket adopts an opened socket instead of
473 // opening one directly.
TEST_F(UDPSocketTest,AdoptedSocket)474 TEST_F(UDPSocketTest, AdoptedSocket) {
475 auto socketfd =
476 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
477 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
478 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
479
480 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd), IsOk());
481
482 // Connect to an IPv6 address should fail since the socket was created for
483 // IPv4.
484 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
485 Not(IsOk()));
486
487 // Make sure that UDPSocket actually closed the socket.
488 EXPECT_FALSE(socket.is_connected());
489 }
490
491 // Tests that UDPSocket updates the global counter correctly.
TEST_F(UDPSocketTest,LimitAdoptSocket)492 TEST_F(UDPSocketTest, LimitAdoptSocket) {
493 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
494 {
495 // Creating a platform socket does not increase count.
496 auto socketfd =
497 CreatePlatformSocket(ConvertAddressFamily(ADDRESS_FAMILY_IPV4),
498 SOCK_DGRAM, AF_UNIX ? 0 : IPPROTO_UDP);
499 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
500
501 // Simply allocating a UDPSocket does not increase count.
502 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
503 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
504
505 // Calling AdoptOpenedSocket() allocates the socket and increases the global
506 // counter.
507 EXPECT_THAT(socket.AdoptOpenedSocket(ADDRESS_FAMILY_IPV4, socketfd),
508 IsOk());
509 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
510
511 // Connect to an IPv6 address should fail since the socket was created for
512 // IPv4.
513 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
514 Not(IsOk()));
515
516 // That Connect() failed doesn't change the global counter.
517 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
518 }
519 // Finally, destroying UDPSocket decrements the global counter.
520 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
521 }
522
523 // In this test, we verify that connect() on a socket will have the effect
524 // of filtering reads on this socket only to data read from the destination
525 // we connected to.
526 //
527 // The purpose of this test is that some documentation indicates that connect
528 // binds the client's sends to send to a particular server endpoint, but does
529 // not bind the client's reads to only be from that endpoint, and that we need
530 // to always use recvfrom() to disambiguate.
TEST_F(UDPSocketTest,VerifyConnectBindsAddr)531 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
532 std::string simple_message("hello world!");
533 std::string foreign_message("BAD MESSAGE TO GET!!");
534
535 // Setup the first server to listen.
536 IPEndPoint server1_address(IPAddress::IPv4Localhost(), 0 /* port */);
537 UDPServerSocket server1(nullptr, NetLogSource());
538 ASSERT_THAT(server1.Listen(server1_address), IsOk());
539 // Get the bound port.
540 ASSERT_THAT(server1.GetLocalAddress(&server1_address), IsOk());
541
542 // Setup the second server to listen.
543 IPEndPoint server2_address(IPAddress::IPv4Localhost(), 0 /* port */);
544 UDPServerSocket server2(nullptr, NetLogSource());
545 ASSERT_THAT(server2.Listen(server2_address), IsOk());
546
547 // Setup the client, connected to server 1.
548 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
549 EXPECT_THAT(client.Connect(server1_address), IsOk());
550
551 // Client sends to server1.
552 EXPECT_EQ(simple_message.length(),
553 static_cast<size_t>(WriteSocket(&client, simple_message)));
554
555 // Server1 waits for message.
556 std::string str = RecvFromSocket(&server1);
557 EXPECT_EQ(simple_message, str);
558
559 // Get the client's address.
560 IPEndPoint client_address;
561 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
562
563 // Server2 sends reply.
564 EXPECT_EQ(foreign_message.length(),
565 static_cast<size_t>(
566 SendToSocket(&server2, foreign_message, client_address)));
567
568 // Server1 sends reply.
569 EXPECT_EQ(simple_message.length(),
570 static_cast<size_t>(
571 SendToSocket(&server1, simple_message, client_address)));
572
573 // Client waits for response.
574 str = ReadSocket(&client);
575 EXPECT_EQ(simple_message, str);
576 }
577
TEST_F(UDPSocketTest,ClientGetLocalPeerAddresses)578 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
579 struct TestData {
580 std::string remote_address;
581 std::string local_address;
582 bool may_fail;
583 } tests[] = {
584 {"127.0.00.1", "127.0.0.1", false},
585 {"::1", "::1", true},
586 #if !BUILDFLAG(IS_ANDROID) && !BUILDFLAG(IS_IOS)
587 // Addresses below are disabled on Android. See crbug.com/161248
588 // They are also disabled on iOS. See https://crbug.com/523225
589 {"192.168.1.1", "127.0.0.1", false},
590 {"2001:db8:0::42", "::1", true},
591 #endif
592 };
593 for (const auto& test : tests) {
594 SCOPED_TRACE(std::string("Connecting from ") + test.local_address +
595 std::string(" to ") + test.remote_address);
596
597 IPAddress ip_address;
598 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.remote_address));
599 IPEndPoint remote_address(ip_address, 80);
600 EXPECT_TRUE(ip_address.AssignFromIPLiteral(test.local_address));
601 IPEndPoint local_address(ip_address, 80);
602
603 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
604 NetLogSource());
605 int rv = client.Connect(remote_address);
606 if (test.may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
607 // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
608 // addresses if IPv6 is not configured.
609 continue;
610 }
611
612 EXPECT_LE(ERR_IO_PENDING, rv);
613
614 IPEndPoint fetched_local_address;
615 rv = client.GetLocalAddress(&fetched_local_address);
616 EXPECT_THAT(rv, IsOk());
617
618 // TODO(mbelshe): figure out how to verify the IP and port.
619 // The port is dynamically generated by the udp stack.
620 // The IP is the real IP of the client, not necessarily
621 // loopback.
622 // EXPECT_EQ(local_address.address(), fetched_local_address.address());
623
624 IPEndPoint fetched_remote_address;
625 rv = client.GetPeerAddress(&fetched_remote_address);
626 EXPECT_THAT(rv, IsOk());
627
628 EXPECT_EQ(remote_address, fetched_remote_address);
629 }
630 }
631
// A server bound to an ephemeral loopback port reports the requested address
// together with a nonzero port.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  const IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
  UDPServerSocket server(nullptr, NetLogSource());
  EXPECT_THAT(server.Listen(bind_address), IsOk());

  IPEndPoint local_address;
  EXPECT_THAT(server.GetLocalAddress(&local_address), IsOk());

  // The stack must have assigned a real port on the requested address.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(local_address.address(), bind_address.address());
}
646
// A listening server socket has no peer, so GetPeerAddress() must report
// ERR_SOCKET_NOT_CONNECTED.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  UDPServerSocket server(nullptr, NetLogSource());
  EXPECT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
              IsOk());

  IPEndPoint peer_address;
  EXPECT_THAT(server.GetPeerAddress(&peer_address),
              IsError(ERR_SOCKET_NOT_CONNECTED));
}
657
TEST_F(UDPSocketTest,ClientSetDoNotFragment)658 TEST_F(UDPSocketTest, ClientSetDoNotFragment) {
659 for (std::string ip : {"127.0.0.1", "::1"}) {
660 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr,
661 NetLogSource());
662 IPAddress ip_address;
663 EXPECT_TRUE(ip_address.AssignFromIPLiteral(ip));
664 IPEndPoint remote_address(ip_address, 80);
665 int rv = client.Connect(remote_address);
666 // May fail on IPv6 is IPv6 is not configured.
667 if (ip_address.IsIPv6() && rv == ERR_ADDRESS_UNREACHABLE)
668 return;
669 EXPECT_THAT(rv, IsOk());
670
671 rv = client.SetDoNotFragment();
672 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
673 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
674 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
675 #elif BUILDFLAG(IS_MAC)
676 if (base::mac::MacOSMajorVersion() >= 11) {
677 EXPECT_THAT(rv, IsOk());
678 } else {
679 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
680 }
681 #else
682 EXPECT_THAT(rv, IsOk());
683 #endif
684 }
685 }
686
TEST_F(UDPSocketTest,ServerSetDoNotFragment)687 TEST_F(UDPSocketTest, ServerSetDoNotFragment) {
688 for (std::string ip : {"127.0.0.1", "::1"}) {
689 IPEndPoint bind_address;
690 ASSERT_TRUE(CreateUDPAddress(ip, 0, &bind_address));
691 UDPServerSocket server(nullptr, NetLogSource());
692 int rv = server.Listen(bind_address);
693 // May fail on IPv6 is IPv6 is not configure
694 if (bind_address.address().IsIPv6() &&
695 (rv == ERR_ADDRESS_INVALID || rv == ERR_ADDRESS_UNREACHABLE))
696 return;
697 EXPECT_THAT(rv, IsOk());
698
699 rv = server.SetDoNotFragment();
700 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_FUCHSIA)
701 // TODO(crbug.com/945590): IP_MTU_DISCOVER is not implemented on Fuchsia.
702 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
703 #elif BUILDFLAG(IS_MAC)
704 if (base::mac::MacOSMajorVersion() >= 11) {
705 EXPECT_THAT(rv, IsOk());
706 } else {
707 EXPECT_THAT(rv, IsError(ERR_NOT_IMPLEMENTED));
708 }
709 #else
710 EXPECT_THAT(rv, IsOk());
711 #endif
712 }
713 }
714
715 // Close the socket while read is pending.
TEST_F(UDPSocketTest,CloseWithPendingRead)716 TEST_F(UDPSocketTest, CloseWithPendingRead) {
717 IPEndPoint bind_address(IPAddress::IPv4Localhost(), 0);
718 UDPServerSocket server(nullptr, NetLogSource());
719 int rv = server.Listen(bind_address);
720 EXPECT_THAT(rv, IsOk());
721
722 TestCompletionCallback callback;
723 IPEndPoint from;
724 rv = server.RecvFrom(buffer_.get(), kMaxRead, &from, callback.callback());
725 EXPECT_EQ(rv, ERR_IO_PENDING);
726
727 server.Close();
728
729 EXPECT_FALSE(callback.have_result());
730 }
731
732 // Some Android devices do not support multicast.
// The ones supporting multicast need WifiManager.MulticastLock to enable it.
734 // http://goo.gl/jjAk9
735 #if !BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest,JoinMulticastGroup)736 TEST_F(UDPSocketTest, JoinMulticastGroup) {
737 const char kGroup[] = "237.132.100.17";
738
739 IPAddress group_ip;
740 EXPECT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
741 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
742 // OS_FUCHSIA.
743 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
744 IPEndPoint bind_address(IPAddress::AllZeros(group_ip.size()), 0 /* port */);
745 #else
746 IPEndPoint bind_address(group_ip, 0 /* port */);
747 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
748
749 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
750 EXPECT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
751
752 EXPECT_THAT(socket.Bind(bind_address), IsOk());
753 EXPECT_THAT(socket.JoinGroup(group_ip), IsOk());
754 // Joining group multiple times.
755 EXPECT_NE(OK, socket.JoinGroup(group_ip));
756 EXPECT_THAT(socket.LeaveGroup(group_ip), IsOk());
757 // Leaving group multiple times.
758 EXPECT_NE(OK, socket.LeaveGroup(group_ip));
759
760 socket.Close();
761 }
762
763 // TODO(https://crbug.com/947115): failing on device on iOS 12.2.
764 // TODO(https://crbug.com/1227554): flaky on Mac 11.
765 #if BUILDFLAG(IS_IOS) || BUILDFLAG(IS_MAC)
766 #define MAYBE_SharedMulticastAddress DISABLED_SharedMulticastAddress
767 #else
768 #define MAYBE_SharedMulticastAddress SharedMulticastAddress
769 #endif
TEST_F(UDPSocketTest,MAYBE_SharedMulticastAddress)770 TEST_F(UDPSocketTest, MAYBE_SharedMulticastAddress) {
771 const char kGroup[] = "224.0.0.251";
772
773 IPAddress group_ip;
774 ASSERT_TRUE(group_ip.AssignFromIPLiteral(kGroup));
775 // TODO(https://github.com/google/gvisor/issues/3839): don't guard on
776 // OS_FUCHSIA.
777 #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
778 IPEndPoint receive_address(IPAddress::AllZeros(group_ip.size()),
779 0 /* port */);
780 #else
781 IPEndPoint receive_address(group_ip, 0 /* port */);
782 #endif // BUILDFLAG(IS_WIN) || BUILDFLAG(IS_FUCHSIA)
783
784 NetworkInterfaceList interfaces;
785 ASSERT_TRUE(GetNetworkList(&interfaces, 0));
786 // The test fails with the Hyper-V switch interface (on the host side).
787 interfaces.erase(std::remove_if(interfaces.begin(), interfaces.end(),
788 [](const auto& iface) {
789 return iface.friendly_name.rfind(
790 "vEthernet", 0) == 0;
791 }),
792 interfaces.end());
793 ASSERT_FALSE(interfaces.empty());
794
795 // Setup first receiving socket.
796 UDPServerSocket socket1(nullptr, NetLogSource());
797 socket1.AllowAddressSharingForMulticast();
798 ASSERT_THAT(socket1.SetMulticastInterface(interfaces[0].interface_index),
799 IsOk());
800 ASSERT_THAT(socket1.Listen(receive_address), IsOk());
801 ASSERT_THAT(socket1.JoinGroup(group_ip), IsOk());
802 // Get the bound port.
803 ASSERT_THAT(socket1.GetLocalAddress(&receive_address), IsOk());
804
805 // Setup second receiving socket.
806 UDPServerSocket socket2(nullptr, NetLogSource());
807 socket2.AllowAddressSharingForMulticast(), IsOk();
808 ASSERT_THAT(socket2.SetMulticastInterface(interfaces[0].interface_index),
809 IsOk());
810 ASSERT_THAT(socket2.Listen(receive_address), IsOk());
811 ASSERT_THAT(socket2.JoinGroup(group_ip), IsOk());
812
813 // Setup client socket.
814 IPEndPoint send_address(group_ip, receive_address.port());
815 UDPClientSocket client_socket(DatagramSocket::DEFAULT_BIND, nullptr,
816 NetLogSource());
817 ASSERT_THAT(client_socket.Connect(send_address), IsOk());
818
819 #if !BUILDFLAG(IS_CHROMEOS_ASH)
820 // Send a message via the multicast group. That message is expected be be
821 // received by both receving sockets.
822 //
823 // Skip on ChromeOS where it's known to sometimes not work.
824 // TODO(crbug.com/898964): If possible, fix and reenable.
825 const char kMessage[] = "hello!";
826 ASSERT_GE(WriteSocket(&client_socket, kMessage), 0);
827 EXPECT_EQ(kMessage, RecvFromSocket(&socket1));
828 EXPECT_EQ(kMessage, RecvFromSocket(&socket2));
829 #endif // !BUILDFLAG(IS_CHROMEOS_ASH)
830 }
831 #endif // !BUILDFLAG(IS_ANDROID)
832
// Verifies that multicast options are accepted on a socket before Bind() and
// rejected once the socket has been bound.
TEST_F(UDPSocketTest, MulticastOptions) {
  IPEndPoint bind_address;
  ASSERT_TRUE(CreateUDPAddress("0.0.0.0", 0 /* port */, &bind_address));

  UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  // Before binding.
  EXPECT_THAT(socket.SetMulticastLoopbackMode(false), IsOk());
  EXPECT_THAT(socket.SetMulticastLoopbackMode(true), IsOk());
  EXPECT_THAT(socket.SetMulticastTimeToLive(0), IsOk());
  EXPECT_THAT(socket.SetMulticastTimeToLive(3), IsOk());
  // A negative TTL is invalid even before binding.
  EXPECT_NE(OK, socket.SetMulticastTimeToLive(-1));
  EXPECT_THAT(socket.SetMulticastInterface(0), IsOk());

  // The checks below only make sense on a bound socket. Stop here if opening
  // or binding fails, rather than reporting misleading follow-on failures.
  ASSERT_THAT(socket.Open(bind_address.GetFamily()), IsOk());
  ASSERT_THAT(socket.Bind(bind_address), IsOk());

  // After binding, every multicast setter must be rejected.
  EXPECT_NE(OK, socket.SetMulticastLoopbackMode(false));
  EXPECT_NE(OK, socket.SetMulticastTimeToLive(0));
  EXPECT_NE(OK, socket.SetMulticastInterface(0));

  socket.Close();
}
855
856 // Checking that DSCP bits are set correctly is difficult,
857 // but let's check that the code doesn't crash at least.
TEST_F(UDPSocketTest,SetDSCP)858 TEST_F(UDPSocketTest, SetDSCP) {
859 // Setup the server to listen.
860 IPEndPoint bind_address;
861 UDPSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
862 // We need a real IP, but we won't actually send anything to it.
863 ASSERT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
864 int rv = client.Open(bind_address.GetFamily());
865 EXPECT_THAT(rv, IsOk());
866
867 rv = client.Connect(bind_address);
868 if (rv != OK) {
869 // Let's try localhost then.
870 bind_address = IPEndPoint(IPAddress::IPv4Localhost(), 9999);
871 rv = client.Connect(bind_address);
872 }
873 EXPECT_THAT(rv, IsOk());
874
875 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
876 client.SetDiffServCodePoint(DSCP_AF41);
877 client.SetDiffServCodePoint(DSCP_DEFAULT);
878 client.SetDiffServCodePoint(DSCP_CS2);
879 client.SetDiffServCodePoint(DSCP_NO_CHANGE);
880 client.SetDiffServCodePoint(DSCP_DEFAULT);
881 client.Close();
882 }
883
884 // Send DSCP + ECN marked packets from server to client and verify the TOS
885 // bytes that arrive.
TEST_F(UDPSocketTest,VerifyDscpAndEcnExchange)886 TEST_F(UDPSocketTest, VerifyDscpAndEcnExchange) {
887 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0);
888 UDPServerSocket server(nullptr, NetLogSource());
889 server.AllowAddressReuse();
890 ASSERT_THAT(server.Listen(server_address), IsOk());
891 // Get bound port.
892 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
893 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
894 client.Connect(server_address);
895 EXPECT_EQ(client.SetRecvTos(), 0);
896 IPEndPoint client_address;
897 client.GetLocalAddress(&client_address);
898
899 EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0);
900 std::string first_message = "foobar";
901 EXPECT_EQ(SendToSocket(&server, first_message, client_address),
902 static_cast<int>(first_message.length()));
903 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data());
904
905 std::string second_message = "foo";
906 EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0);
907 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
908 static_cast<int>(second_message.length()));
909 EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data());
910
911 #if BUILDFLAG(IS_WIN)
912 // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark.
913 EcnCodePoint final_ecn = ECN_ECT1;
914 #else
915 EcnCodePoint final_ecn = ECN_CE;
916 #endif
917
918 EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0);
919 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
920 static_cast<int>(second_message.length()));
921 EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data());
922
923 EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0);
924 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
925 static_cast<int>(second_message.length()));
926 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
927
928 EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0);
929 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
930 static_cast<int>(second_message.length()));
931 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
932
933 server.Close();
934 client.Close();
935 }
936
937 // For windows, test with Nonblocking sockets. For other platforms, this test
938 // is identical to VerifyDscpAndEcnExchange, above.
TEST_F(UDPSocketTest,VerifyDscpAndEcnExchangeNonBlocking)939 TEST_F(UDPSocketTest, VerifyDscpAndEcnExchangeNonBlocking) {
940 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0);
941 UDPServerSocket server(nullptr, NetLogSource());
942 server.UseNonBlockingIO();
943 server.AllowAddressReuse();
944 ASSERT_THAT(server.Listen(server_address), IsOk());
945 // Get bound port.
946 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
947 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
948 client.UseNonBlockingIO();
949 client.Connect(server_address);
950 EXPECT_EQ(client.SetRecvTos(), 0);
951 IPEndPoint client_address;
952 client.GetLocalAddress(&client_address);
953
954 EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_ECT1), 0);
955 std::string first_message = "foobar";
956 EXPECT_EQ(SendToSocket(&server, first_message, client_address),
957 static_cast<int>(first_message.length()));
958 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, ECN_ECT1), first_message.data());
959
960 std::string second_message = "foo";
961 EXPECT_EQ(server.SetTos(DSCP_CS2, ECN_ECT0), 0);
962 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
963 static_cast<int>(second_message.length()));
964 EXPECT_EQ(ReadSocket(&client, DSCP_CS2, ECN_ECT0), second_message.data());
965
966 // The Windows sendmsg API does not allow setting ECN_CE as the outgoing mark.
967 EcnCodePoint final_ecn = ECN_ECT1;
968
969 EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, final_ecn), 0);
970 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
971 static_cast<int>(second_message.length()));
972 EXPECT_EQ(ReadSocket(&client, DSCP_CS2, final_ecn), second_message.data());
973
974 EXPECT_EQ(server.SetTos(DSCP_AF41, ECN_NO_CHANGE), 0);
975 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
976 static_cast<int>(second_message.length()));
977 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
978
979 EXPECT_EQ(server.SetTos(DSCP_NO_CHANGE, ECN_NO_CHANGE), 0);
980 EXPECT_EQ(SendToSocket(&server, second_message, client_address),
981 static_cast<int>(second_message.length()));
982 EXPECT_EQ(ReadSocket(&client, DSCP_AF41, final_ecn), second_message.data());
983
984 server.Close();
985 client.Close();
986 }
987
TEST_F(UDPSocketTest,ConnectUsingNetwork)988 TEST_F(UDPSocketTest, ConnectUsingNetwork) {
989 // The specific value of this address doesn't really matter, and no
990 // server needs to be running here. The test only needs to call
991 // ConnectUsingNetwork() and won't send any datagrams.
992 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
993 const handles::NetworkHandle wrong_network_handle = 65536;
994 #if BUILDFLAG(IS_ANDROID)
995 NetworkChangeNotifierFactoryAndroid ncn_factory;
996 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
997 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
998 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
999 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
1000
1001 {
1002 // Connecting using a not existing network should fail but not report
1003 // ERR_NOT_IMPLEMENTED when network handles are supported.
1004 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1005 NetLogSource());
1006 int rv =
1007 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address);
1008 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
1009 EXPECT_NE(OK, rv);
1010 EXPECT_NE(wrong_network_handle, socket.GetBoundNetwork());
1011 }
1012
1013 {
1014 // Connecting using an existing network should succeed when
1015 // NetworkChangeNotifier returns a valid default network.
1016 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1017 NetLogSource());
1018 const handles::NetworkHandle network_handle =
1019 NetworkChangeNotifier::GetDefaultNetwork();
1020 if (network_handle != handles::kInvalidNetworkHandle) {
1021 EXPECT_EQ(
1022 OK, socket.ConnectUsingNetwork(network_handle, fake_server_address));
1023 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
1024 }
1025 }
1026 #else
1027 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
1028 EXPECT_EQ(
1029 ERR_NOT_IMPLEMENTED,
1030 socket.ConnectUsingNetwork(wrong_network_handle, fake_server_address));
1031 #endif // BUILDFLAG(IS_ANDROID)
1032 }
1033
TEST_F(UDPSocketTest,ConnectUsingNetworkAsync)1034 TEST_F(UDPSocketTest, ConnectUsingNetworkAsync) {
1035 // The specific value of this address doesn't really matter, and no
1036 // server needs to be running here. The test only needs to call
1037 // ConnectUsingNetwork() and won't send any datagrams.
1038 const IPEndPoint fake_server_address(IPAddress::IPv4Localhost(), 8080);
1039 const handles::NetworkHandle wrong_network_handle = 65536;
1040 #if BUILDFLAG(IS_ANDROID)
1041 NetworkChangeNotifierFactoryAndroid ncn_factory;
1042 NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
1043 std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
1044 if (!NetworkChangeNotifier::AreNetworkHandlesSupported())
1045 GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
1046
1047 {
1048 // Connecting using a not existing network should fail but not report
1049 // ERR_NOT_IMPLEMENTED when network handles are supported.
1050 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1051 NetLogSource());
1052 TestCompletionCallback callback;
1053 int rv = socket.ConnectUsingNetworkAsync(
1054 wrong_network_handle, fake_server_address, callback.callback());
1055
1056 if (rv == ERR_IO_PENDING) {
1057 rv = callback.WaitForResult();
1058 }
1059 EXPECT_NE(ERR_NOT_IMPLEMENTED, rv);
1060 EXPECT_NE(OK, rv);
1061 }
1062
1063 {
1064 // Connecting using an existing network should succeed when
1065 // NetworkChangeNotifier returns a valid default network.
1066 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr,
1067 NetLogSource());
1068 TestCompletionCallback callback;
1069 const handles::NetworkHandle network_handle =
1070 NetworkChangeNotifier::GetDefaultNetwork();
1071 if (network_handle != handles::kInvalidNetworkHandle) {
1072 int rv = socket.ConnectUsingNetworkAsync(
1073 network_handle, fake_server_address, callback.callback());
1074 if (rv == ERR_IO_PENDING) {
1075 rv = callback.WaitForResult();
1076 }
1077 EXPECT_EQ(OK, rv);
1078 EXPECT_EQ(network_handle, socket.GetBoundNetwork());
1079 }
1080 }
1081 #else
1082 UDPClientSocket socket(DatagramSocket::RANDOM_BIND, nullptr, NetLogSource());
1083 TestCompletionCallback callback;
1084 EXPECT_EQ(ERR_NOT_IMPLEMENTED, socket.ConnectUsingNetworkAsync(
1085 wrong_network_handle, fake_server_address,
1086 callback.callback()));
1087 #endif // BUILDFLAG(IS_ANDROID)
1088 }
1089
1090 } // namespace
1091
1092 #if BUILDFLAG(IS_WIN)
1093
1094 namespace {
1095
1096 const HANDLE kFakeHandle1 = (HANDLE)12;
1097 const HANDLE kFakeHandle2 = (HANDLE)13;
1098
1099 const QOS_FLOWID kFakeFlowId1 = (QOS_FLOWID)27;
1100 const QOS_FLOWID kFakeFlowId2 = (QOS_FLOWID)38;
1101
1102 class TestUDPSocketWin : public UDPSocketWin {
1103 public:
TestUDPSocketWin(QwaveApi * qos,DatagramSocket::BindType bind_type,net::NetLog * net_log,const net::NetLogSource & source)1104 TestUDPSocketWin(QwaveApi* qos,
1105 DatagramSocket::BindType bind_type,
1106 net::NetLog* net_log,
1107 const net::NetLogSource& source)
1108 : UDPSocketWin(bind_type, net_log, source), qos_(qos) {}
1109
1110 TestUDPSocketWin(const TestUDPSocketWin&) = delete;
1111 TestUDPSocketWin& operator=(const TestUDPSocketWin&) = delete;
1112
1113 // Overriding GetQwaveApi causes the test class to use the injected mock
1114 // QwaveApi instance instead of the singleton.
GetQwaveApi() const1115 QwaveApi* GetQwaveApi() const override { return qos_; }
1116
1117 private:
1118 raw_ptr<QwaveApi> qos_;
1119 };
1120
1121 class MockQwaveApi : public QwaveApi {
1122 public:
1123 MOCK_CONST_METHOD0(qwave_supported, bool());
1124 MOCK_METHOD0(OnFatalError, void());
1125 MOCK_METHOD2(CreateHandle, BOOL(PQOS_VERSION version, PHANDLE handle));
1126 MOCK_METHOD1(CloseHandle, BOOL(HANDLE handle));
1127 MOCK_METHOD6(AddSocketToFlow,
1128 BOOL(HANDLE handle,
1129 SOCKET socket,
1130 PSOCKADDR addr,
1131 QOS_TRAFFIC_TYPE traffic_type,
1132 DWORD flags,
1133 PQOS_FLOWID flow_id));
1134
1135 MOCK_METHOD4(
1136 RemoveSocketFromFlow,
1137 BOOL(HANDLE handle, SOCKET socket, QOS_FLOWID flow_id, DWORD reserved));
1138 MOCK_METHOD7(SetFlow,
1139 BOOL(HANDLE handle,
1140 QOS_FLOWID flow_id,
1141 QOS_SET_FLOW op,
1142 ULONG size,
1143 PVOID data,
1144 DWORD reserved,
1145 LPOVERLAPPED overlapped));
1146 };
1147
OpenedDscpTestClient(QwaveApi * api,IPEndPoint bind_address)1148 std::unique_ptr<UDPSocket> OpenedDscpTestClient(QwaveApi* api,
1149 IPEndPoint bind_address) {
1150 auto client = std::make_unique<TestUDPSocketWin>(
1151 api, DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1152 int rv = client->Open(bind_address.GetFamily());
1153 EXPECT_THAT(rv, IsOk());
1154
1155 return client;
1156 }
1157
ConnectedDscpTestClient(QwaveApi * api)1158 std::unique_ptr<UDPSocket> ConnectedDscpTestClient(QwaveApi* api) {
1159 IPEndPoint bind_address;
1160 // We need a real IP, but we won't actually send anything to it.
1161 EXPECT_TRUE(CreateUDPAddress("8.8.8.8", 9999, &bind_address));
1162 auto client = OpenedDscpTestClient(api, bind_address);
1163 EXPECT_THAT(client->Connect(bind_address), IsOk());
1164 return client;
1165 }
1166
UnconnectedDscpTestClient(QwaveApi * api)1167 std::unique_ptr<UDPSocket> UnconnectedDscpTestClient(QwaveApi* api) {
1168 IPEndPoint bind_address;
1169 EXPECT_TRUE(CreateUDPAddress("0.0.0.0", 9999, &bind_address));
1170 auto client = OpenedDscpTestClient(api, bind_address);
1171 EXPECT_THAT(client->Bind(bind_address), IsOk());
1172 return client;
1173 }
1174
1175 } // namespace
1176
1177 using ::testing::Return;
1178 using ::testing::SetArgPointee;
1179 using ::testing::_;
1180
// Setting DSCP_NO_CHANGE must succeed without adding the socket to a QoS flow.
TEST_F(UDPSocketTest, SetDSCPNoopIfPassedNoChange) {
  MockQwaveApi qwave;
  EXPECT_CALL(qwave, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(qwave, AddSocketToFlow(_, _, _, _, _, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&qwave);
  const int result = socket->SetDiffServCodePoint(DSCP_NO_CHANGE);
  EXPECT_THAT(result, IsOk());
}
1189
// When qWAVE is reported as unsupported, no QoS handle may be created and
// setting a real DSCP value reports ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfQOSDoesntLink) {
  MockQwaveApi qwave;
  EXPECT_CALL(qwave, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_CALL(qwave, CreateHandle(_, _)).Times(0);

  std::unique_ptr<UDPSocket> socket = ConnectedDscpTestClient(&qwave);
  const int result = socket->SetDiffServCodePoint(DSCP_AF41);
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, result);
}
1198
// Verifies that a failing CreateHandle() triggers QwaveApi::OnFatalError()
// exactly once. Once qwave is reported unsupported, SetDiffServCodePoint()
// returns ERR_NOT_IMPLEMENTED.
TEST_F(UDPSocketTest, SetDSCPFailsIfHandleCantBeCreated) {
  MockQwaveApi api;
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  // Handle creation fails; this is expected to be reported as fatal once.
  EXPECT_CALL(api, CreateHandle(_, _)).WillOnce(Return(false));
  EXPECT_CALL(api, OnFatalError()).Times(1);

  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  // No usable QoS handle exists at this point.
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));

  // Run pending tasks (presumably the asynchronous handle creation that fails).
  RunUntilIdle();

  // gMock prefers the newest matching expectation, so this overrides the
  // earlier qwave_supported() result. NOTE(review): presumably this mimics
  // the real QwaveApi disabling itself after OnFatalError(); confirm.
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(false));
  EXPECT_EQ(ERR_NOT_IMPLEMENTED, client->SetDiffServCodePoint(DSCP_AF41));
}
1213
1214 MATCHER_P(DscpPointee, dscp, "") {
1215 return *(DWORD*)arg == (DWORD)dscp;
1216 }
1217
// Verifies that on a connected socket, DSCP can only be applied once the QoS
// handle has been created. It also checks that changing the DSCP value
// replaces the existing flow.
TEST_F(UDPSocketTest, ConnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = ConnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // Once the handle is available, one flow (kFakeFlowId1) is expected to be
  // created and configured.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));

  // First set on connected sockets will fail since init is async and
  // we haven't given the runloop a chance to execute the callback.
  EXPECT_EQ(ERR_INVALID_HANDLE, client->SetDiffServCodePoint(DSCP_AF41));
  RunUntilIdle();
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());

  // New dscp value should reset the flow.
  // The old flow is removed. A new one is created with the traffic type
  // expected for DSCP_DEFAULT (QOSTrafficTypeBestEffort), and its outgoing
  // DSCP value is set explicitly through SetFlow().
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeBestEffort, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, QOSSetOutgoingDSCPValue, _,
                           DscpPointee(DSCP_DEFAULT), _, _));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_DEFAULT), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1247
// For an unconnected socket, SetDiffServCodePoint() succeeds both before and
// after the QoS handle has been created.
TEST_F(UDPSocketTest, UnonnectedSocketDelayedInitAndUpdate) {
  MockQwaveApi qwave;
  EXPECT_CALL(qwave, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(qwave, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // CreateHandle won't have completed yet. Set passes.
  std::unique_ptr<UDPSocket> socket = UnconnectedDscpTestClient(&qwave);
  const int early_result = socket->SetDiffServCodePoint(DSCP_AF41);
  EXPECT_THAT(early_result, IsOk());

  // Let pending work (including handle creation) run, then change the value.
  RunUntilIdle();
  const int late_result = socket->SetDiffServCodePoint(DSCP_AF42);
  EXPECT_THAT(late_result, IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(qwave, CloseHandle(kFakeHandle1));
}
1264
1265 // TODO(zstein): Mocking out DscpManager might be simpler here
1266 // (just verify that DscpManager::Set and DscpManager::PrepareForSend are
1267 // called).
// Verifies that, after SetDiffServCodePoint() on an unconnected socket, SendTo
// adds the socket to a QoS flow. This happens once per destination address,
// and the flow is configured only once.
TEST_F(UDPSocketTest, SendToCallsQwaveApis) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_AF41), IsOk());
  RunUntilIdle();

  // The first send to |server_address| is expected to create a flow
  // (kFakeFlowId1) and configure it.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _));
  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to second test case (Qwave APIs called once per address)
  // A second send to the same address must not call the QoS APIs again; the
  // WillOnce()/single-call expectations above would be oversaturated.
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // TODO(zstein): Move to third test case (Qwave APIs called for each
  // destination address).
  // A new destination requires AddSocketToFlow() again, but no new SetFlow().
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(true));
  IPEndPoint server_address2(IPAddress::IPv4Localhost(), 9439);

  rv = SendToSocket(client.get(), simple_message, server_address2);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, _, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1301
// Verifies that a DSCP value set before the QoS handle exists is remembered.
// Sends before initialization go out without QoS calls. Sends after it apply
// the remembered code point.
TEST_F(UDPSocketTest, SendToCallsApisAfterDeferredInit) {
  MockQwaveApi api;
  std::unique_ptr<UDPSocket> client = UnconnectedDscpTestClient(&api);
  EXPECT_CALL(api, qwave_supported()).WillRepeatedly(Return(true));
  EXPECT_CALL(api, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));

  // SetDiffServCodepoint works even if qos api hasn't finished initing.
  EXPECT_THAT(client->SetDiffServCodePoint(DSCP_CS7), IsOk());

  std::string simple_message("hello world");
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 9438);

  // SendTo works, but doesn't yet apply TOS
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  int rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  RunUntilIdle();
  // Now we're initialized, SendTo triggers qos calls with correct codepoint.
  // DSCP_CS7 is expected to map to QOSTrafficTypeControl. This newer
  // expectation takes precedence over the Times(0) one above.
  EXPECT_CALL(api, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api, SetFlow(_, _, _, _, _, _, _)).WillOnce(Return(true));
  rv = SendToSocket(client.get(), simple_message, server_address);
  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Called from DscpManager destructor.
  EXPECT_CALL(api, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api, CloseHandle(kFakeHandle1));
}
1332
1333 class DscpManagerTest : public TestWithTaskEnvironment {
1334 protected:
DscpManagerTest()1335 DscpManagerTest() {
1336 EXPECT_CALL(api_, qwave_supported()).WillRepeatedly(Return(true));
1337 EXPECT_CALL(api_, CreateHandle(_, _))
1338 .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle1), Return(true)));
1339 dscp_manager_ = std::make_unique<DscpManager>(&api_, INVALID_SOCKET);
1340
1341 CreateUDPAddress("1.2.3.4", 9001, &address1_);
1342 CreateUDPAddress("1234:5678:90ab:cdef:1234:5678:90ab:cdef", 9002,
1343 &address2_);
1344 }
1345
1346 MockQwaveApi api_;
1347 std::unique_ptr<DscpManager> dscp_manager_;
1348
1349 IPEndPoint address1_;
1350 IPEndPoint address2_;
1351 };
1352
// Without a prior Set(), PrepareForSend() must succeed without touching any
// QoS flow API. The expectations below enforce that instead of relying on
// uninteresting-call warnings.
TEST_F(DscpManagerTest, PrepareForSendIsNoopIfNoSet) {
  RunUntilIdle();

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1357
// After Set(), each new destination is added to the shared flow. The flow is
// configured only once, when it is first created.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisAfterSet) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // AddSocketToFlow should be called for each address.
  // SetFlow should only be called when the flow is first created.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // The second destination joins the same flow (kFakeFlowId1).
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address2_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1378
// Repeated PrepareForSend() calls for the same destination must not call the
// QoS APIs again after the first one.
TEST_F(DscpManagerTest, PrepareForSendCallsQwaveApisOncePerAddress) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);
  // Same address again: neither API may be called.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1395
// Changing the code point with Set() removes the socket from the current flow.
// The next PrepareForSend() then builds a fresh flow.
TEST_F(DscpManagerTest, SetDestroysExistingFlow) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Calling Set should destroy the existing flow.
  // TODO(zstein): Verify that RemoveSocketFromFlow with no address
  // destroys the flow for all destinations.
  // The socket argument is expected to be NULL (not this socket).
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, NULL, kFakeFlowId1, _));
  dscp_manager_->Set(DSCP_CS5);

  // A new flow (kFakeFlowId2) is created and configured for the new value.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _));
  dscp_manager_->PrepareForSend(address1_);

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
}
1420
// If AddSocketToFlow() fails with ERROR_DEVICE_REINITIALIZATION_NEEDED, the
// manager must close its handle and request a new one. The next send must then
// work with the new handle without another Set().
TEST_F(DscpManagerTest, SocketReAddedOnRecreateHandle) {
  RunUntilIdle();
  dscp_manager_->Set(DSCP_CS2);

  // First Set and Send work fine.
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId1), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId1, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Make Second flow operation fail (requires resetting the codepoint).
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId1, _))
      .WillOnce(Return(true));
  dscp_manager_->Set(DSCP_CS7);

  // NOTE(review): ScopedClearLastError presumably restores the previous
  // last-error value when destroyed (|error = nullptr| below), so the fake
  // error set here does not leak past this step.
  auto error = std::make_unique<base::ScopedClearLastError>();
  ::SetLastError(ERROR_DEVICE_REINITIALIZATION_NEEDED);
  // With this last error, the failed AddSocketToFlow() is expected to make
  // the manager close kFakeHandle1 and create a replacement (kFakeHandle2).
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, _, _, _)).WillOnce(Return(false));
  EXPECT_CALL(api_, SetFlow(_, _, _, _, _, _, _)).Times(0);
  EXPECT_CALL(api_, CloseHandle(kFakeHandle1));
  EXPECT_CALL(api_, CreateHandle(_, _))
      .WillOnce(DoAll(SetArgPointee<1>(kFakeHandle2), Return(true)));
  EXPECT_EQ(ERR_INVALID_HANDLE, dscp_manager_->PrepareForSend(address1_));
  error = nullptr;
  RunUntilIdle();

  // Next Send should work fine, without requiring another Set
  EXPECT_CALL(api_, AddSocketToFlow(_, _, _, QOSTrafficTypeControl, _, _))
      .WillOnce(DoAll(SetArgPointee<5>(kFakeFlowId2), Return(true)));
  EXPECT_CALL(api_, SetFlow(_, kFakeFlowId2, _, _, _, _, _))
      .WillOnce(Return(true));
  EXPECT_THAT(dscp_manager_->PrepareForSend(address1_), IsOk());

  // Called from DscpManager destructor.
  EXPECT_CALL(api_, RemoveSocketFromFlow(_, _, kFakeFlowId2, _));
  EXPECT_CALL(api_, CloseHandle(kFakeHandle2));
}
1459
1460 #endif
1461
// Verifies that a client with the experimental receive optimization enabled
// still receives a datagram sent by the server intact.
TEST_F(UDPSocketTest, ReadWithSocketOptimization) {
  std::string simple_message("hello world!");

  // Setup the server to listen.
  IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
  UDPServerSocket server(nullptr, NetLogSource());
  server.AllowAddressReuse();
  ASSERT_THAT(server.Listen(server_address), IsOk());
  // Get bound port.
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  // Setup the client, enable experimental optimization and connect it to the
  // server. The server needs the client's address to reply, so both the
  // connect and the address lookup are prerequisites for the rest.
  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  client.EnableRecvOptimization();
  ASSERT_THAT(client.Connect(server_address), IsOk());

  // Get the client's address.
  IPEndPoint client_address;
  ASSERT_THAT(client.GetLocalAddress(&client_address), IsOk());

  // Server sends the message to the client.
  EXPECT_EQ(simple_message.length(),
            static_cast<size_t>(
                SendToSocket(&server, simple_message, client_address)));

  // Client receives the message.
  std::string str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);

  server.Close();
  client.Close();
}
1495
1496 // Tests that read from a socket correctly returns
1497 // |ERR_MSG_TOO_BIG| when the buffer is too small and
1498 // returns the actual message when it fits the buffer.
1499 // For the optimized path, the buffer size should be at least
1500 // 1 byte greater than the message.
TEST_F(UDPSocketTest,ReadWithSocketOptimizationTruncation)1501 TEST_F(UDPSocketTest, ReadWithSocketOptimizationTruncation) {
1502 std::string too_long_message(kMaxRead + 1, 'A');
1503 std::string right_length_message(kMaxRead - 1, 'B');
1504 std::string exact_length_message(kMaxRead, 'C');
1505
1506 // Setup the server to listen.
1507 IPEndPoint server_address(IPAddress::IPv4Localhost(), 0 /* port */);
1508 UDPServerSocket server(nullptr, NetLogSource());
1509 server.AllowAddressReuse();
1510 ASSERT_THAT(server.Listen(server_address), IsOk());
1511 // Get bound port.
1512 ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());
1513
1514 // Setup the client, enable experimental optimization and connected to the
1515 // server.
1516 UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1517 client.EnableRecvOptimization();
1518 EXPECT_THAT(client.Connect(server_address), IsOk());
1519
1520 // Get the client's address.
1521 IPEndPoint client_address;
1522 EXPECT_THAT(client.GetLocalAddress(&client_address), IsOk());
1523
1524 // Send messages to the client.
1525 EXPECT_EQ(too_long_message.length(),
1526 static_cast<size_t>(
1527 SendToSocket(&server, too_long_message, client_address)));
1528 EXPECT_EQ(right_length_message.length(),
1529 static_cast<size_t>(
1530 SendToSocket(&server, right_length_message, client_address)));
1531 EXPECT_EQ(exact_length_message.length(),
1532 static_cast<size_t>(
1533 SendToSocket(&server, exact_length_message, client_address)));
1534
1535 // Client receives the messages.
1536
1537 // 1. The first message is |too_long_message|. Its size exceeds the buffer.
1538 // In that case, the client is expected to get |ERR_MSG_TOO_BIG| when the
1539 // data is read.
1540 TestCompletionCallback callback;
1541 int rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1542 EXPECT_EQ(ERR_MSG_TOO_BIG, callback.GetResult(rv));
1543 EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1544 EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1545
1546 // 2. The second message is |right_length_message|. Its size is
1547 // one byte smaller than the size of the buffer. In that case, the client
1548 // is expected to read the whole message successfully.
1549 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1550 rv = callback.GetResult(rv);
1551 EXPECT_EQ(static_cast<int>(right_length_message.length()), rv);
1552 EXPECT_EQ(right_length_message, std::string(buffer_->data(), rv));
1553 EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1554 EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1555
1556 // 3. The third message is |exact_length_message|. Its size is equal to
1557 // the read buffer size. In that case, the client expects to get
1558 // |ERR_MSG_TOO_BIG| when the socket is read. Internally, the optimized
1559 // path uses read() system call that requires one extra byte to detect
1560 // truncated messages; therefore, messages that fill the buffer exactly
1561 // are considered truncated.
1562 // The optimization is only enabled on POSIX platforms. On Windows,
1563 // the optimization is turned off; therefore, the client
1564 // should be able to read the whole message without encountering
1565 // |ERR_MSG_TOO_BIG|.
1566 rv = client.Read(buffer_.get(), kMaxRead, callback.callback());
1567 rv = callback.GetResult(rv);
1568 EXPECT_EQ(client.GetLastTos().dscp, DSCP_DEFAULT);
1569 EXPECT_EQ(client.GetLastTos().ecn, ECN_DEFAULT);
1570 #if BUILDFLAG(IS_POSIX)
1571 EXPECT_EQ(ERR_MSG_TOO_BIG, rv);
1572 #else
1573 EXPECT_EQ(static_cast<int>(exact_length_message.length()), rv);
1574 EXPECT_EQ(exact_length_message, std::string(buffer_->data(), rv));
1575 #endif
1576 server.Close();
1577 client.Close();
1578 }
1579
// On Android, where socket tagging is supported, verify that tagging a UDP
// socket via ApplySocketTag() works as expected.
1582 #if BUILDFLAG(IS_ANDROID)
TEST_F(UDPSocketTest, Tag) {
  // Report the test as skipped (rather than silently passing) when per-tag
  // traffic counters cannot be read on this device.
  if (!CanGetTaggedBytes()) {
    GTEST_SKIP() << "Skipping test - GetTaggedBytes unsupported.";
  }

  UDPServerSocket server(nullptr, NetLogSource());
  ASSERT_THAT(server.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0)), IsOk());
  IPEndPoint server_address;
  ASSERT_THAT(server.GetLocalAddress(&server_address), IsOk());

  UDPClientSocket client(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
  ASSERT_THAT(client.Connect(server_address), IsOk());

  const std::string simple_message("hello world!");

  // Performs one round trip: the client sends |simple_message| to the server,
  // the server echoes it back, and the client reads the reply. Each leg is
  // checked for the expected length and payload. |phase| is attached to any
  // failure so it is clear which tagging step produced it.
  auto send_and_echo = [&](const char* phase) {
    SCOPED_TRACE(phase);
    int rv = WriteSocket(&client, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    EXPECT_EQ(simple_message, RecvFromSocket(&server));
    rv = SendToSocket(&server, simple_message);
    EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    EXPECT_EQ(simple_message, ReadSocket(&client));
  };

  // Verify UDP packets are tagged and counted properly.
  const int32_t tag_val1 = 0x12345678;
  uint64_t old_traffic = GetTaggedBytes(tag_val1);
  const SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
  client.ApplySocketTag(tag1);
  send_and_echo("tag1");
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);

  // Verify the socket can be retagged with a new value and the current
  // process's UID. 0x87654321 does not fit in int32_t, so the conversion to a
  // (negative) tag value is spelled out explicitly.
  const int32_t tag_val2 = static_cast<int32_t>(0x87654321);
  old_traffic = GetTaggedBytes(tag_val2);
  const SocketTag tag2(getuid(), tag_val2);
  client.ApplySocketTag(tag2);
  send_and_echo("tag2");
  EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);

  // Verify the socket can be retagged back to a previously applied tag (the
  // first tag value, with an unset UID).
  old_traffic = GetTaggedBytes(tag_val1);
  client.ApplySocketTag(tag1);
  send_and_echo("tag1 reapplied");
  EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
}
1655
TEST_F(UDPSocketTest, BindToNetwork) {
  // Only Connect() is exercised and no datagrams are sent, so nothing has to
  // listen on this address and its exact value is unimportant.
  const IPEndPoint unused_server_address(IPAddress::IPv4Localhost(), 8080);

  // Install an Android NetworkChangeNotifier for the duration of the test.
  NetworkChangeNotifierFactoryAndroid factory;
  NetworkChangeNotifier::DisableForTest disable_for_test;
  std::unique_ptr<NetworkChangeNotifier> notifier(factory.CreateInstance());
  if (!NetworkChangeNotifier::AreNetworkHandlesSupported()) {
    GTEST_SKIP() << "Network handles are required to test BindToNetwork.";
  }

  // A socket bound to a network that does not exist must fail to connect.
  // The exact error differs between Android versions, so only the outcomes
  // that must never happen are ruled out.
  const handles::NetworkHandle bogus_network = 65536;
  UDPClientSocket bogus_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), bogus_network);
  const int connect_result = bogus_socket.Connect(unused_server_address);
  EXPECT_NE(OK, connect_result);
  EXPECT_NE(ERR_NOT_IMPLEMENTED, connect_result);
  EXPECT_NE(bogus_network, bogus_socket.GetBoundNetwork());

  // A socket bound to the default network must connect successfully. Nothing
  // further can be checked when there is no default network.
  const handles::NetworkHandle default_network =
      NetworkChangeNotifier::GetDefaultNetwork();
  if (default_network == handles::kInvalidNetworkHandle) {
    return;
  }
  UDPClientSocket bound_socket(DatagramSocket::RANDOM_BIND, nullptr,
                               NetLogSource(), default_network);
  EXPECT_EQ(OK, bound_socket.Connect(unused_server_address));
  EXPECT_EQ(default_network, bound_socket.GetBoundNetwork());
}
1688
1689 #endif // BUILDFLAG(IS_ANDROID)
1690
1691 // Scoped helper to override the process-wide UDP socket limit.
1692 class OverrideUDPSocketLimit {
1693 public:
OverrideUDPSocketLimit(int new_limit)1694 explicit OverrideUDPSocketLimit(int new_limit) {
1695 base::FieldTrialParams params;
1696 params[features::kLimitOpenUDPSocketsMax.name] =
1697 base::NumberToString(new_limit);
1698
1699 scoped_feature_list_.InitAndEnableFeatureWithParameters(
1700 features::kLimitOpenUDPSockets, params);
1701 }
1702
1703 private:
1704 base::test::ScopedFeatureList scoped_feature_list_;
1705 };
1706
1707 // Tests that UDPClientSocket respects the global UDP socket limits.
TEST_F(UDPSocketTest,LimitClientSocket)1708 TEST_F(UDPSocketTest, LimitClientSocket) {
1709 // Reduce the global UDP limit to 2.
1710 OverrideUDPSocketLimit set_limit(2);
1711
1712 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1713
1714 auto socket1 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1715 nullptr, NetLogSource());
1716 auto socket2 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1717 nullptr, NetLogSource());
1718
1719 // Simply constructing a UDPClientSocket does not increase the limit (no
1720 // Connect() or Bind() has been called yet).
1721 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1722
1723 // The specific value of this address doesn't really matter, and no server
1724 // needs to be running here. The test only needs to call Connect() and won't
1725 // send any datagrams.
1726 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1727
1728 // Successful Connect() on socket1 increases socket count.
1729 EXPECT_THAT(socket1->Connect(server_address), IsOk());
1730 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1731
1732 // Successful Connect() on socket2 increases socket count.
1733 EXPECT_THAT(socket2->Connect(server_address), IsOk());
1734 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1735
1736 // Attempting a third Connect() should fail with ERR_INSUFFICIENT_RESOURCES,
1737 // as the limit is currently 2.
1738 auto socket3 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1739 nullptr, NetLogSource());
1740 EXPECT_THAT(socket3->Connect(server_address),
1741 IsError(ERR_INSUFFICIENT_RESOURCES));
1742 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1743
1744 // Check that explicitly closing socket2 free up a count.
1745 socket2->Close();
1746 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1747
1748 // Since the socket was already closed, deleting it will not affect the count.
1749 socket2.reset();
1750 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1751
1752 // Now that the count is below limit, try to connect another socket. This time
1753 // it will work.
1754 auto socket4 = std::make_unique<UDPClientSocket>(DatagramSocket::DEFAULT_BIND,
1755 nullptr, NetLogSource());
1756 EXPECT_THAT(socket4->Connect(server_address), IsOk());
1757 EXPECT_EQ(2, GetGlobalUDPSocketCountForTesting());
1758
1759 // Verify that closing the two remaining sockets brings the open count back to
1760 // 0.
1761 socket1.reset();
1762 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1763 socket4.reset();
1764 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1765 }
1766
1767 // Tests that UDPSocketClient updates the global counter
1768 // correctly when Connect() fails.
TEST_F(UDPSocketTest,LimitConnectFail)1769 TEST_F(UDPSocketTest, LimitConnectFail) {
1770 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1771
1772 {
1773 // Simply allocating a UDPSocket does not increase count.
1774 UDPSocket socket(DatagramSocket::DEFAULT_BIND, nullptr, NetLogSource());
1775 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1776
1777 // Calling Open() allocates the socket and increases the global counter.
1778 EXPECT_THAT(socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
1779 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1780
1781 // Connect to an IPv6 address should fail since the socket was created for
1782 // IPv4.
1783 EXPECT_THAT(socket.Connect(net::IPEndPoint(IPAddress::IPv6Localhost(), 53)),
1784 Not(IsOk()));
1785
1786 // That Connect() failed doesn't change the global counter.
1787 EXPECT_EQ(1, GetGlobalUDPSocketCountForTesting());
1788 }
1789
1790 // Finally, destroying UDPSocket decrements the global counter.
1791 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1792 }
1793
1794 // Tests allocating UDPClientSockets and Connect()ing them in parallel.
1795 //
1796 // This is primarily intended for coverage under TSAN, to check for races
1797 // enforcing the global socket counter.
TEST_F(UDPSocketTest,LimitConnectMultithreaded)1798 TEST_F(UDPSocketTest, LimitConnectMultithreaded) {
1799 ASSERT_EQ(0, GetGlobalUDPSocketCountForTesting());
1800
1801 // Start up some threads.
1802 std::vector<std::unique_ptr<base::Thread>> threads;
1803 for (size_t i = 0; i < 5; ++i) {
1804 threads.push_back(std::make_unique<base::Thread>("Worker thread"));
1805 ASSERT_TRUE(threads.back()->Start());
1806 }
1807
1808 // Post tasks to each of the threads.
1809 for (const auto& thread : threads) {
1810 thread->task_runner()->PostTask(
1811 FROM_HERE, base::BindOnce([] {
1812 // The specific value of this address doesn't really matter, and no
1813 // server needs to be running here. The test only needs to call
1814 // Connect() and won't send any datagrams.
1815 IPEndPoint server_address(IPAddress::IPv4Localhost(), 8080);
1816
1817 UDPClientSocket socket(DatagramSocket::DEFAULT_BIND, nullptr,
1818 NetLogSource());
1819 EXPECT_THAT(socket.Connect(server_address), IsOk());
1820 }));
1821 }
1822
1823 // Complete all the tasks.
1824 threads.clear();
1825
1826 EXPECT_EQ(0, GetGlobalUDPSocketCountForTesting());
1827 }
1828
1829 } // namespace net
1830