xref: /aosp_15_r20/external/cronet/net/socket/socket_test_util.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 #include "base/memory/raw_ptr.h"
7 
8 #include <inttypes.h>  // For SCNx64
9 #include <stdint.h>
10 #include <stdio.h>
11 
12 #include <memory>
13 #include <ostream>
14 #include <string>
15 #include <string_view>
16 #include <utility>
17 #include <vector>
18 
19 #include "base/compiler_specific.h"
20 #include "base/files/file_util.h"
21 #include "base/functional/bind.h"
22 #include "base/functional/callback_helpers.h"
23 #include "base/location.h"
24 #include "base/logging.h"
25 #include "base/rand_util.h"
26 #include "base/ranges/algorithm.h"
27 #include "base/run_loop.h"
28 #include "base/task/single_thread_task_runner.h"
29 #include "base/time/time.h"
30 #include "build/build_config.h"
31 #include "net/base/address_family.h"
32 #include "net/base/address_list.h"
33 #include "net/base/auth.h"
34 #include "net/base/hex_utils.h"
35 #include "net/base/ip_address.h"
36 #include "net/base/load_timing_info.h"
37 #include "net/base/proxy_server.h"
38 #include "net/http/http_network_session.h"
39 #include "net/http/http_request_headers.h"
40 #include "net/http/http_response_headers.h"
41 #include "net/log/net_log_source.h"
42 #include "net/log/net_log_source_type.h"
43 #include "net/socket/connect_job.h"
44 #include "net/socket/socket.h"
45 #include "net/socket/stream_socket.h"
46 #include "net/socket/websocket_endpoint_lock_manager.h"
47 #include "net/ssl/ssl_cert_request_info.h"
48 #include "net/ssl/ssl_connection_status_flags.h"
49 #include "net/ssl/ssl_info.h"
50 #include "net/traffic_annotation/network_traffic_annotation.h"
51 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
52 #include "testing/gtest/include/gtest/gtest.h"
53 #include "third_party/abseil-cpp/absl/strings/ascii.h"
54 
55 #if BUILDFLAG(IS_ANDROID)
56 #include "base/android/build_info.h"
57 #endif
58 
59 #define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "
60 
61 namespace net {
62 namespace {
63 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the high
// nybble (bits 4-7) of |x|.
inline char AsciifyHigh(char x) {
  const int high_nybble = (x >> 4) & 0x0F;
  return "0123456789ABCDEF"[high_nybble];
}
68 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the low
// nybble (bits 0-3) of |x|.
inline char AsciifyLow(char x) {
  const int low_nybble = x & 0x0F;
  return "0123456789ABCDEF"[low_nybble];
}
73 
Asciify(char x)74 inline char Asciify(char x) {
75   return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.';
76 }
77 
DumpData(const char * data,int data_len)78 void DumpData(const char* data, int data_len) {
79   if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
80     return;
81   }
82   DVLOG(1) << "Length:  " << data_len;
83   const char* pfx = "Data:    ";
84   if (!data || (data_len <= 0)) {
85     DVLOG(1) << pfx << "<None>";
86   } else {
87     int i;
88     for (i = 0; i <= (data_len - 4); i += 4) {
89       DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
90                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
91                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
92                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << "  '"
93                << Asciify(data[i + 0]) << Asciify(data[i + 1])
94                << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
95       pfx = "         ";
96     }
97     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
98     switch (data_len - i) {
99       case 3:
100         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
101                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
102                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
103                  << "    '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
104                  << Asciify(data[i + 2]) << " '";
105         break;
106       case 2:
107         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
108                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
109                  << "      '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
110                  << "  '";
111         break;
112       case 1:
113         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
114                  << "        '" << Asciify(data[i + 0]) << "   '";
115         break;
116     }
117   }
118 }
119 
120 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)121 void DumpMockReadWrite(const MockReadWrite<type>& r) {
122   if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
123     return;
124   }
125   DVLOG(1) << "Async:   " << (r.mode == ASYNC) << "\nResult:  " << r.result;
126   DumpData(r.data, r.data_len);
127   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
128   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
129 }
130 
RunClosureIfNonNull(base::OnceClosure closure)131 void RunClosureIfNonNull(base::OnceClosure closure) {
132   if (!closure.is_null()) {
133     std::move(closure).Run();
134   }
135 }
136 
137 }  // namespace
138 
MockConnect()139 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
140   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
141 }
142 
MockConnect(IoMode io_mode,int r)143 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
144   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
145 }
146 
MockConnect(IoMode io_mode,int r,IPEndPoint addr)147 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
148     : mode(io_mode), result(r), peer_addr(addr) {}
149 
MockConnect(IoMode io_mode,int r,IPEndPoint addr,bool first_attempt_fails)150 MockConnect::MockConnect(IoMode io_mode,
151                          int r,
152                          IPEndPoint addr,
153                          bool first_attempt_fails)
154     : mode(io_mode),
155       result(r),
156       peer_addr(addr),
157       first_attempt_fails(first_attempt_fails) {}
158 
159 MockConnect::~MockConnect() = default;
160 
MockConfirm()161 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
162 
MockConfirm(IoMode io_mode,int r)163 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
164 
165 MockConfirm::~MockConfirm() = default;
166 
IsIdle() const167 bool SocketDataProvider::IsIdle() const {
168   return true;
169 }
170 
Initialize(AsyncSocket * socket)171 void SocketDataProvider::Initialize(AsyncSocket* socket) {
172   CHECK(!socket_);
173   CHECK(socket);
174   socket_ = socket;
175   Reset();
176 }
177 
DetachSocket()178 void SocketDataProvider::DetachSocket() {
179   CHECK(socket_);
180   socket_ = nullptr;
181 }
182 
183 SocketDataProvider::SocketDataProvider() = default;
184 
~SocketDataProvider()185 SocketDataProvider::~SocketDataProvider() {
186   if (socket_)
187     socket_->OnDataProviderDestroyed();
188 }
189 
StaticSocketDataHelper(base::span<const MockRead> reads,base::span<const MockWrite> writes)190 StaticSocketDataHelper::StaticSocketDataHelper(
191     base::span<const MockRead> reads,
192     base::span<const MockWrite> writes)
193     : reads_(reads), writes_(writes) {}
194 
195 StaticSocketDataHelper::~StaticSocketDataHelper() = default;
196 
PeekRead() const197 const MockRead& StaticSocketDataHelper::PeekRead() const {
198   CHECK(!AllReadDataConsumed());
199   return reads_[read_index_];
200 }
201 
PeekWrite() const202 const MockWrite& StaticSocketDataHelper::PeekWrite() const {
203   CHECK(!AllWriteDataConsumed());
204   return writes_[write_index_];
205 }
206 
AdvanceRead()207 const MockRead& StaticSocketDataHelper::AdvanceRead() {
208   CHECK(!AllReadDataConsumed());
209   return reads_[read_index_++];
210 }
211 
AdvanceWrite()212 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
213   CHECK(!AllWriteDataConsumed());
214   return writes_[write_index_++];
215 }
216 
Reset()217 void StaticSocketDataHelper::Reset() {
218   read_index_ = 0;
219   write_index_ = 0;
220 }
221 
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)222 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
223                                              SocketDataPrinter* printer) {
224   CHECK(!AllWriteDataConsumed());
225   // Check that the actual data matches the expectations, skipping over any
226   // pause events.
227   const MockWrite& next_write = PeekRealWrite();
228   if (!next_write.data)
229     return true;
230 
231   // Note: Partial writes are supported here.  If the expected data
232   // is a match, but shorter than the write actually written, that is legal.
233   // Example:
234   //   Application writes "foobarbaz" (9 bytes)
235   //   Expected write was "foo" (3 bytes)
236   //   This is a success, and the function returns true.
237   std::string expected_data(next_write.data, next_write.data_len);
238   std::string actual_data(data.substr(0, next_write.data_len));
239   if (printer) {
240     EXPECT_TRUE(actual_data == expected_data)
241         << "Actual formatted write data:\n"
242         << printer->PrintWrite(data) << "Expected formatted write data:\n"
243         << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
244         << HexDump(data) << "Expected raw write data:\n"
245         << HexDump(expected_data);
246   } else {
247     EXPECT_TRUE(actual_data == expected_data)
248         << "Actual write data:\n"
249         << HexDump(data) << "Expected write data:\n"
250         << HexDump(expected_data);
251   }
252   return expected_data == actual_data;
253 }
254 
ExpectAllReadDataConsumed(SocketDataPrinter * printer) const255 void StaticSocketDataHelper::ExpectAllReadDataConsumed(
256     SocketDataPrinter* printer) const {
257   if (AllReadDataConsumed()) {
258     return;
259   }
260 
261   std::ostringstream msg;
262   if (read_index_ < read_count()) {
263     msg << "Unconsumed reads:\n";
264     for (size_t i = read_index_; i < read_count(); i++) {
265       msg << (reads_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockRead seq "
266           << reads_[i].sequence_number << ":\n";
267       if (reads_[i].result != OK) {
268         msg << "Result: " << reads_[i].result << "\n";
269       }
270       if (reads_[i].data) {
271         std::string data(reads_[i].data, reads_[i].data_len);
272         if (printer) {
273           msg << printer->PrintWrite(data);
274         }
275         msg << HexDump(data);
276       }
277     }
278   }
279   EXPECT_TRUE(AllReadDataConsumed()) << msg.str();
280 }
281 
ExpectAllWriteDataConsumed(SocketDataPrinter * printer) const282 void StaticSocketDataHelper::ExpectAllWriteDataConsumed(
283     SocketDataPrinter* printer) const {
284   if (AllWriteDataConsumed()) {
285     return;
286   }
287 
288   std::ostringstream msg;
289   if (write_index_ < write_count()) {
290     msg << "Unconsumed writes:\n";
291     for (size_t i = write_index_; i < write_count(); i++) {
292       msg << (writes_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockWrite seq "
293           << writes_[i].sequence_number << ":\n";
294       if (writes_[i].result != OK) {
295         msg << "Result: " << writes_[i].result << "\n";
296       }
297       if (writes_[i].data) {
298         std::string data(writes_[i].data, writes_[i].data_len);
299         if (printer) {
300           msg << printer->PrintWrite(data);
301         }
302         msg << HexDump(data);
303       }
304     }
305   }
306   EXPECT_TRUE(AllWriteDataConsumed()) << msg.str();
307 }
308 
PeekRealWrite() const309 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
310   for (size_t i = write_index_; i < write_count(); i++) {
311     if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
312       return writes_[i];
313   }
314 
315   CHECK(false) << "No write data available.";
316   return writes_[0];  // Avoid warning about unreachable missing return.
317 }
318 
StaticSocketDataProvider()319 StaticSocketDataProvider::StaticSocketDataProvider()
320     : StaticSocketDataProvider(base::span<const MockRead>(),
321                                base::span<const MockWrite>()) {}
322 
StaticSocketDataProvider(base::span<const MockRead> reads,base::span<const MockWrite> writes)323 StaticSocketDataProvider::StaticSocketDataProvider(
324     base::span<const MockRead> reads,
325     base::span<const MockWrite> writes)
326     : helper_(reads, writes) {}
327 
328 StaticSocketDataProvider::~StaticSocketDataProvider() = default;
329 
Pause()330 void StaticSocketDataProvider::Pause() {
331   paused_ = true;
332 }
333 
Resume()334 void StaticSocketDataProvider::Resume() {
335   paused_ = false;
336 }
337 
OnRead()338 MockRead StaticSocketDataProvider::OnRead() {
339   if (AllReadDataConsumed()) {
340     const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
341     return pending_read;
342   }
343 
344   return helper_.AdvanceRead();
345 }
346 
OnWrite(const std::string & data)347 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
348   if (helper_.write_count() == 0) {
349     // Not using mock writes; succeed synchronously.
350     return MockWriteResult(SYNCHRONOUS, data.length());
351   }
352   if (printer_) {
353     EXPECT_FALSE(helper_.AllWriteDataConsumed())
354         << "No more mock data to match write:\nFormatted write data:\n"
355         << printer_->PrintWrite(data) << "Raw write data:\n"
356         << HexDump(data);
357   } else {
358     EXPECT_FALSE(helper_.AllWriteDataConsumed())
359         << "No more mock data to match write:\nRaw write data:\n"
360         << HexDump(data);
361   }
362   if (helper_.AllWriteDataConsumed()) {
363     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
364   }
365 
366   // Check that what we are writing matches the expectation.
367   // Then give the mocked return value.
368   if (!helper_.VerifyWriteData(data, printer_))
369     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
370 
371   const MockWrite& next_write = helper_.AdvanceWrite();
372   // In the case that the write was successful, return the number of bytes
373   // written. Otherwise return the error code.
374   int result =
375       next_write.result == OK ? next_write.data_len : next_write.result;
376   return MockWriteResult(next_write.mode, result);
377 }
378 
AllReadDataConsumed() const379 bool StaticSocketDataProvider::AllReadDataConsumed() const {
380   return paused_ || helper_.AllReadDataConsumed();
381 }
382 
AllWriteDataConsumed() const383 bool StaticSocketDataProvider::AllWriteDataConsumed() const {
384   return helper_.AllWriteDataConsumed();
385 }
386 
Reset()387 void StaticSocketDataProvider::Reset() {
388   helper_.Reset();
389 }
390 
SSLSocketDataProvider(IoMode mode,int result)391 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
392     : connect(mode, result),
393       expected_ssl_version_min(kDefaultSSLVersionMin),
394       expected_ssl_version_max(kDefaultSSLVersionMax) {
395   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
396                                 &ssl_info.connection_status);
397   // Set to TLS_CHACHA20_POLY1305_SHA256
398   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
399 }
400 
401 SSLSocketDataProvider::SSLSocketDataProvider(
402     const SSLSocketDataProvider& other) = default;
403 
404 SSLSocketDataProvider::~SSLSocketDataProvider() = default;
405 
SequencedSocketData()406 SequencedSocketData::SequencedSocketData()
407     : SequencedSocketData(base::span<const MockRead>(),
408                           base::span<const MockWrite>()) {}
409 
SequencedSocketData(base::span<const MockRead> reads,base::span<const MockWrite> writes)410 SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
411                                          base::span<const MockWrite> writes)
412     : helper_(reads, writes) {
413   // Check that reads and writes have a contiguous set of sequence numbers
414   // starting from 0 and working their way up, with no repeats and skipping
415   // no values.
416   int next_sequence_number = 0;
417   bool last_event_was_pause = false;
418 
419   auto next_read = reads.begin();
420   auto next_write = writes.begin();
421   while (next_read != reads.end() || next_write != writes.end()) {
422     if (next_read != reads.end() &&
423         next_read->sequence_number == next_sequence_number) {
424       // Check if this is a pause.
425       if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
426         CHECK(!last_event_was_pause)
427             << "Two pauses in a row are not allowed: " << next_sequence_number;
428         last_event_was_pause = true;
429       } else if (last_event_was_pause) {
430         CHECK_EQ(ASYNC, next_read->mode)
431             << "A sync event after a pause makes no sense: "
432             << next_sequence_number;
433         CHECK_NE(ERR_IO_PENDING, next_read->result)
434             << "A pause event after a pause makes no sense: "
435             << next_sequence_number;
436         last_event_was_pause = false;
437       }
438 
439       ++next_read;
440       ++next_sequence_number;
441       continue;
442     }
443     if (next_write != writes.end() &&
444         next_write->sequence_number == next_sequence_number) {
445       // Check if this is a pause.
446       if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
447         CHECK(!last_event_was_pause)
448             << "Two pauses in a row are not allowed: " << next_sequence_number;
449         last_event_was_pause = true;
450       } else if (last_event_was_pause) {
451         CHECK_EQ(ASYNC, next_write->mode)
452             << "A sync event after a pause makes no sense: "
453             << next_sequence_number;
454         CHECK_NE(ERR_IO_PENDING, next_write->result)
455             << "A pause event after a pause makes no sense: "
456             << next_sequence_number;
457         last_event_was_pause = false;
458       }
459 
460       ++next_write;
461       ++next_sequence_number;
462       continue;
463     }
464     if (next_write != writes.end()) {
465       CHECK(false) << "Sequence number " << next_write->sequence_number
466                    << " not found where expected: " << next_sequence_number;
467     } else {
468       CHECK(false) << "Too few writes, next expected sequence number: "
469                    << next_sequence_number;
470     }
471     return;
472   }
473 
474   // Last event must not be a pause.  For the final event to indicate the
475   // operation never completes, it should be SYNCHRONOUS and return
476   // ERR_IO_PENDING.
477   CHECK(!last_event_was_pause);
478 
479   CHECK(next_read == reads.end());
480   CHECK(next_write == writes.end());
481 }
482 
SequencedSocketData(const MockConnect & connect,base::span<const MockRead> reads,base::span<const MockWrite> writes)483 SequencedSocketData::SequencedSocketData(const MockConnect& connect,
484                                          base::span<const MockRead> reads,
485                                          base::span<const MockWrite> writes)
486     : SequencedSocketData(reads, writes) {
487   set_connect_data(connect);
488 }
OnRead()489 MockRead SequencedSocketData::OnRead() {
490   CHECK_EQ(IoState::kIdle, read_state_);
491   CHECK(!helper_.AllReadDataConsumed())
492       << "Application tried to read but there is no read data left";
493 
494   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
495   const MockRead& next_read = helper_.PeekRead();
496   NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
497   CHECK_GE(next_read.sequence_number, sequence_number_);
498 
499   if (next_read.sequence_number <= sequence_number_) {
500     if (next_read.mode == SYNCHRONOUS) {
501       NET_TRACE(1, " *** ") << "Returning synchronously";
502       DumpMockReadWrite(next_read);
503       helper_.AdvanceRead();
504       ++sequence_number_;
505       MaybePostWriteCompleteTask();
506       return next_read;
507     }
508 
509     // If the result is ERR_IO_PENDING, then pause.
510     if (next_read.result == ERR_IO_PENDING) {
511       NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
512       read_state_ = IoState::kPaused;
513       if (run_until_paused_run_loop_)
514         run_until_paused_run_loop_->Quit();
515       return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
516     }
517     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
518         FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
519                                   weak_factory_.GetWeakPtr()));
520     CHECK_NE(IoState::kCompleting, write_state_);
521     read_state_ = IoState::kCompleting;
522   } else if (next_read.mode == SYNCHRONOUS) {
523     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
524     return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
525   } else {
526     NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
527     read_state_ = IoState::kPending;
528   }
529 
530   return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
531 }
532 
OnWrite(const std::string & data)533 MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
534   CHECK_EQ(IoState::kIdle, write_state_);
535   if (printer_) {
536     CHECK(!helper_.AllWriteDataConsumed())
537         << "\nNo more mock data to match write:\nFormatted write data:\n"
538         << printer_->PrintWrite(data) << "Raw write data:\n"
539         << HexDump(data);
540   } else {
541     CHECK(!helper_.AllWriteDataConsumed())
542         << "\nNo more mock data to match write:\nRaw write data:\n"
543         << HexDump(data);
544   }
545 
546   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
547   const MockWrite& next_write = helper_.PeekWrite();
548   NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
549   CHECK_GE(next_write.sequence_number, sequence_number_);
550 
551   if (!helper_.VerifyWriteData(data, printer_))
552     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
553 
554   if (next_write.sequence_number <= sequence_number_) {
555     if (next_write.mode == SYNCHRONOUS) {
556       helper_.AdvanceWrite();
557       ++sequence_number_;
558       MaybePostReadCompleteTask();
559       // In the case that the write was successful, return the number of bytes
560       // written. Otherwise return the error code.
561       int rv =
562           next_write.result != OK ? next_write.result : next_write.data_len;
563       NET_TRACE(1, " *** ") << "Returning synchronously";
564       return MockWriteResult(SYNCHRONOUS, rv);
565     }
566 
567     // If the result is ERR_IO_PENDING, then pause.
568     if (next_write.result == ERR_IO_PENDING) {
569       NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
570       write_state_ = IoState::kPaused;
571       if (run_until_paused_run_loop_)
572         run_until_paused_run_loop_->Quit();
573       return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
574     }
575 
576     NET_TRACE(1, " *** ") << "Posting task to complete write";
577     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
578         FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
579                                   weak_factory_.GetWeakPtr()));
580     CHECK_NE(IoState::kCompleting, read_state_);
581     write_state_ = IoState::kCompleting;
582   } else if (next_write.mode == SYNCHRONOUS) {
583     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
584     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
585   } else {
586     NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
587     write_state_ = IoState::kPending;
588   }
589 
590   return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
591 }
592 
AllReadDataConsumed() const593 bool SequencedSocketData::AllReadDataConsumed() const {
594   return helper_.AllReadDataConsumed();
595 }
596 
CancelPendingRead()597 void SequencedSocketData::CancelPendingRead() {
598   DCHECK_EQ(IoState::kPending, read_state_);
599 
600   read_state_ = IoState::kIdle;
601 }
602 
AllWriteDataConsumed() const603 bool SequencedSocketData::AllWriteDataConsumed() const {
604   return helper_.AllWriteDataConsumed();
605 }
606 
ExpectAllReadDataConsumed() const607 void SequencedSocketData::ExpectAllReadDataConsumed() const {
608   helper_.ExpectAllReadDataConsumed(printer_.get());
609 }
610 
ExpectAllWriteDataConsumed() const611 void SequencedSocketData::ExpectAllWriteDataConsumed() const {
612   helper_.ExpectAllWriteDataConsumed(printer_.get());
613 }
614 
IsIdle() const615 bool SequencedSocketData::IsIdle() const {
616   // If |busy_before_sync_reads_| is not set, always considered idle.  If
617   // no reads left, or the next operation is a write, also consider it idle.
618   if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
619       helper_.PeekRead().sequence_number != sequence_number_) {
620     return true;
621   }
622 
623   // If the next operation is synchronous read, treat the socket as not idle.
624   if (helper_.PeekRead().mode == SYNCHRONOUS)
625     return false;
626   return true;
627 }
628 
IsPaused() const629 bool SequencedSocketData::IsPaused() const {
630   // Both states should not be paused.
631   DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
632   return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
633 }
634 
Resume()635 void SequencedSocketData::Resume() {
636   if (!IsPaused()) {
637     ADD_FAILURE() << "Unable to Resume when not paused.";
638     return;
639   }
640 
641   sequence_number_++;
642   if (read_state_ == IoState::kPaused) {
643     read_state_ = IoState::kPending;
644     helper_.AdvanceRead();
645   } else {  // write_state_ == IoState::kPaused
646     write_state_ = IoState::kPending;
647     helper_.AdvanceWrite();
648   }
649 
650   if (!helper_.AllWriteDataConsumed() &&
651       helper_.PeekWrite().sequence_number == sequence_number_) {
652     // The next event hasn't even started yet.  Pausing isn't really needed in
653     // that case, but may as well support it.
654     if (write_state_ != IoState::kPending)
655       return;
656     write_state_ = IoState::kCompleting;
657     OnWriteComplete();
658     return;
659   }
660 
661   CHECK(!helper_.AllReadDataConsumed());
662 
663   // The next event hasn't even started yet.  Pausing isn't really needed in
664   // that case, but may as well support it.
665   if (read_state_ != IoState::kPending)
666     return;
667   read_state_ = IoState::kCompleting;
668   OnReadComplete();
669 }
670 
RunUntilPaused()671 void SequencedSocketData::RunUntilPaused() {
672   CHECK(!run_until_paused_run_loop_);
673 
674   if (IsPaused())
675     return;
676 
677   run_until_paused_run_loop_ = std::make_unique<base::RunLoop>();
678   run_until_paused_run_loop_->Run();
679   run_until_paused_run_loop_.reset();
680   DCHECK(IsPaused());
681 }
682 
MaybePostReadCompleteTask()683 void SequencedSocketData::MaybePostReadCompleteTask() {
684   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
685   // Only trigger the next read to complete if there is already a read pending
686   // which should complete at the current sequence number.
687   if (read_state_ != IoState::kPending ||
688       helper_.PeekRead().sequence_number != sequence_number_) {
689     return;
690   }
691 
692   // If the result is ERR_IO_PENDING, then pause.
693   if (helper_.PeekRead().result == ERR_IO_PENDING) {
694     NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
695     read_state_ = IoState::kPaused;
696     if (run_until_paused_run_loop_)
697       run_until_paused_run_loop_->Quit();
698     return;
699   }
700 
701   NET_TRACE(1, " ****** ") << "Posting task to complete read: "
702                            << sequence_number_;
703   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
704       FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
705                                 weak_factory_.GetWeakPtr()));
706   CHECK_NE(IoState::kCompleting, write_state_);
707   read_state_ = IoState::kCompleting;
708 }
709 
MaybePostWriteCompleteTask()710 void SequencedSocketData::MaybePostWriteCompleteTask() {
711   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
712   // Only trigger the next write to complete if there is already a write pending
713   // which should complete at the current sequence number.
714   if (write_state_ != IoState::kPending ||
715       helper_.PeekWrite().sequence_number != sequence_number_) {
716     return;
717   }
718 
719   // If the result is ERR_IO_PENDING, then pause.
720   if (helper_.PeekWrite().result == ERR_IO_PENDING) {
721     NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
722     write_state_ = IoState::kPaused;
723     if (run_until_paused_run_loop_)
724       run_until_paused_run_loop_->Quit();
725     return;
726   }
727 
728   NET_TRACE(1, " ****** ") << "Posting task to complete write: "
729                            << sequence_number_;
730   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
731       FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
732                                 weak_factory_.GetWeakPtr()));
733   CHECK_NE(IoState::kCompleting, read_state_);
734   write_state_ = IoState::kCompleting;
735 }
736 
Reset()737 void SequencedSocketData::Reset() {
738   helper_.Reset();
739   sequence_number_ = 0;
740   read_state_ = IoState::kIdle;
741   write_state_ = IoState::kIdle;
742   weak_factory_.InvalidateWeakPtrs();
743 }
744 
OnReadComplete()745 void SequencedSocketData::OnReadComplete() {
746   CHECK_EQ(IoState::kCompleting, read_state_);
747   NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;
748 
749   MockRead data = helper_.AdvanceRead();
750   DCHECK_EQ(sequence_number_, data.sequence_number);
751   sequence_number_++;
752   read_state_ = IoState::kIdle;
753 
754   // The result of this read completing might trigger the completion
755   // of a pending write. If so, post a task to complete the write later.
756   // Since the socket may call back into the SequencedSocketData
757   // from socket()->OnReadComplete(), trigger the write task to be posted
758   // before calling that.
759   MaybePostWriteCompleteTask();
760 
761   if (!socket()) {
762     NET_TRACE(1, " *** ") << "No socket available to complete read";
763     return;
764   }
765 
766   NET_TRACE(1, " *** ") << "Completing socket read for: "
767                         << data.sequence_number;
768   DumpMockReadWrite(data);
769   socket()->OnReadComplete(data);
770   NET_TRACE(1, " *** ") << "Done";
771 }
772 
OnWriteComplete()773 void SequencedSocketData::OnWriteComplete() {
774   CHECK_EQ(IoState::kCompleting, write_state_);
775   NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;
776 
777   const MockWrite& data = helper_.AdvanceWrite();
778   DCHECK_EQ(sequence_number_, data.sequence_number);
779   sequence_number_++;
780   write_state_ = IoState::kIdle;
781   int rv = data.result == OK ? data.data_len : data.result;
782 
783   // The result of this write completing might trigger the completion
784   // of a pending read. If so, post a task to complete the read later.
785   // Since the socket may call back into the SequencedSocketData
786   // from socket()->OnWriteComplete(), trigger the write task to be posted
787   // before calling that.
788   MaybePostReadCompleteTask();
789 
790   if (!socket()) {
791     NET_TRACE(1, " *** ") << "No socket available to complete write";
792     return;
793   }
794 
795   NET_TRACE(1, " *** ") << " Completing socket write for: "
796                         << data.sequence_number;
797   socket()->OnWriteComplete(rv);
798   NET_TRACE(1, " *** ") << "Done";
799 }
800 
801 SequencedSocketData::~SequencedSocketData() = default;
802 
803 MockClientSocketFactory::MockClientSocketFactory() = default;
804 
805 MockClientSocketFactory::~MockClientSocketFactory() = default;
806 
AddSocketDataProvider(SocketDataProvider * data)807 void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
808   mock_data_.Add(data);
809 }
810 
AddTcpSocketDataProvider(SocketDataProvider * data)811 void MockClientSocketFactory::AddTcpSocketDataProvider(
812     SocketDataProvider* data) {
813   mock_tcp_data_.Add(data);
814 }
815 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)816 void MockClientSocketFactory::AddSSLSocketDataProvider(
817     SSLSocketDataProvider* data) {
818   mock_ssl_data_.Add(data);
819 }
820 
ResetNextMockIndexes()821 void MockClientSocketFactory::ResetNextMockIndexes() {
822   mock_data_.ResetNextIndex();
823   mock_ssl_data_.ResetNextIndex();
824 }
825 
826 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)827 MockClientSocketFactory::CreateDatagramClientSocket(
828     DatagramSocket::BindType bind_type,
829     NetLog* net_log,
830     const NetLogSource& source) {
831   SocketDataProvider* data_provider = mock_data_.GetNext();
832   auto socket = std::make_unique<MockUDPClientSocket>(data_provider, net_log);
833   if (bind_type == DatagramSocket::RANDOM_BIND)
834     socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
835   udp_client_socket_ports_.push_back(socket->source_port());
836   return std::move(socket);
837 }
838 
839 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)840 MockClientSocketFactory::CreateTransportClientSocket(
841     const AddressList& addresses,
842     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
843     NetworkQualityEstimator* network_quality_estimator,
844     NetLog* net_log,
845     const NetLogSource& source) {
846   SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
847   if (!data_provider)
848     data_provider = mock_data_.GetNext();
849   auto socket =
850       std::make_unique<MockTCPClientSocket>(addresses, net_log, data_provider);
851   if (enable_read_if_ready_)
852     socket->set_enable_read_if_ready(enable_read_if_ready_);
853   return std::move(socket);
854 }
855 
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)856 std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
857     SSLClientContext* context,
858     std::unique_ptr<StreamSocket> stream_socket,
859     const HostPortPair& host_and_port,
860     const SSLConfig& ssl_config) {
861   SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
862   if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
863     EXPECT_TRUE(base::ranges::equal(
864         next_ssl_data->next_protos_expected_in_ssl_config.value(),
865         ssl_config.alpn_protos));
866   }
867   if (next_ssl_data->expected_application_settings) {
868     EXPECT_EQ(*next_ssl_data->expected_application_settings,
869               ssl_config.application_settings);
870   }
871 
872   // The protocol version used is a combination of the per-socket SSLConfig and
873   // the SSLConfigService.
874   EXPECT_EQ(
875       next_ssl_data->expected_ssl_version_min,
876       ssl_config.version_min_override.value_or(context->config().version_min));
877   EXPECT_EQ(
878       next_ssl_data->expected_ssl_version_max,
879       ssl_config.version_max_override.value_or(context->config().version_max));
880 
881   if (next_ssl_data->expected_early_data_enabled) {
882     EXPECT_EQ(*next_ssl_data->expected_early_data_enabled,
883               ssl_config.early_data_enabled);
884   }
885 
886   if (next_ssl_data->expected_send_client_cert) {
887     // Client certificate preferences come from |context|.
888     scoped_refptr<X509Certificate> client_cert;
889     scoped_refptr<SSLPrivateKey> client_private_key;
890     bool send_client_cert = context->GetClientCertificate(
891         host_and_port, &client_cert, &client_private_key);
892 
893     EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
894     // Note |send_client_cert| may be true while |client_cert| is null if the
895     // socket is configured to continue without a certificate, as opposed to
896     // surfacing the certificate challenge.
897     EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
898     if (next_ssl_data->expected_client_cert && client_cert) {
899       EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
900           client_cert.get()));
901     }
902   }
903   if (next_ssl_data->expected_host_and_port) {
904     EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
905   }
906   if (next_ssl_data->expected_ignore_certificate_errors) {
907     EXPECT_EQ(*next_ssl_data->expected_ignore_certificate_errors,
908               ssl_config.ignore_certificate_errors);
909   }
910   if (next_ssl_data->expected_network_anonymization_key) {
911     EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
912               ssl_config.network_anonymization_key);
913   }
914   if (next_ssl_data->expected_ech_config_list) {
915     EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
916               ssl_config.ech_config_list);
917   }
918   return std::make_unique<MockSSLClientSocket>(
919       std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
920 }
921 
MockClientSocket(const NetLogWithSource & net_log)922 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
923     : net_log_(net_log) {
924   local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
925   peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
926 }
927 
SetReceiveBufferSize(int32_t size)928 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
929   return OK;
930 }
931 
SetSendBufferSize(int32_t size)932 int MockClientSocket::SetSendBufferSize(int32_t size) {
933   return OK;
934 }
935 
Bind(const net::IPEndPoint & local_addr)936 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
937   local_addr_ = local_addr;
938   return net::OK;
939 }
940 
SetNoDelay(bool no_delay)941 bool MockClientSocket::SetNoDelay(bool no_delay) {
942   return true;
943 }
944 
SetKeepAlive(bool enable,int delay)945 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
946   return true;
947 }
948 
Disconnect()949 void MockClientSocket::Disconnect() {
950   connected_ = false;
951 }
952 
IsConnected() const953 bool MockClientSocket::IsConnected() const {
954   return connected_;
955 }
956 
IsConnectedAndIdle() const957 bool MockClientSocket::IsConnectedAndIdle() const {
958   return connected_;
959 }
960 
GetPeerAddress(IPEndPoint * address) const961 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
962   if (!IsConnected())
963     return ERR_SOCKET_NOT_CONNECTED;
964   *address = peer_addr_;
965   return OK;
966 }
967 
GetLocalAddress(IPEndPoint * address) const968 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
969   *address = local_addr_;
970   return OK;
971 }
972 
NetLog() const973 const NetLogWithSource& MockClientSocket::NetLog() const {
974   return net_log_;
975 }
976 
GetNegotiatedProtocol() const977 NextProto MockClientSocket::GetNegotiatedProtocol() const {
978   return kProtoUnknown;
979 }
980 
981 MockClientSocket::~MockClientSocket() = default;
982 
RunCallbackAsync(CompletionOnceCallback callback,int result)983 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
984                                         int result) {
985   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
986       FROM_HERE,
987       base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
988                      std::move(callback), result));
989 }
990 
RunCallback(CompletionOnceCallback callback,int result)991 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
992                                    int result) {
993   std::move(callback).Run(result);
994 }
995 
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)996 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
997                                          net::NetLog* net_log,
998                                          SocketDataProvider* data)
999     : MockClientSocket(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
1000       addresses_(addresses),
1001       data_(data),
1002       read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
1003   DCHECK(data_);
1004   peer_addr_ = data->connect_data().peer_addr;
1005   data_->Initialize(this);
1006   if (data_->expected_addresses()) {
1007     EXPECT_EQ(*data_->expected_addresses(), addresses);
1008   }
1009 }
1010 
~MockTCPClientSocket()1011 MockTCPClientSocket::~MockTCPClientSocket() {
1012   if (data_)
1013     data_->DetachSocket();
1014 }
1015 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1016 int MockTCPClientSocket::Read(IOBuffer* buf,
1017                               int buf_len,
1018                               CompletionOnceCallback callback) {
1019   // If the buffer is already in use, a read is already in progress!
1020   DCHECK(!pending_read_buf_);
1021   // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
1022   // takes a weak ptr of the base class, MockClientSocket.
1023   int rv = ReadIfReadyImpl(
1024       buf, buf_len,
1025       base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
1026   if (rv == ERR_IO_PENDING) {
1027     DCHECK(callback);
1028 
1029     pending_read_buf_ = buf;
1030     pending_read_buf_len_ = buf_len;
1031     pending_read_callback_ = std::move(callback);
1032   }
1033   return rv;
1034 }
1035 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1036 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
1037                                      int buf_len,
1038                                      CompletionOnceCallback callback) {
1039   DCHECK(!pending_read_if_ready_callback_);
1040 
1041   if (!enable_read_if_ready_)
1042     return ERR_READ_IF_READY_NOT_IMPLEMENTED;
1043   return ReadIfReadyImpl(buf, buf_len, std::move(callback));
1044 }
1045 
CancelReadIfReady()1046 int MockTCPClientSocket::CancelReadIfReady() {
1047   DCHECK(pending_read_if_ready_callback_);
1048 
1049   pending_read_if_ready_callback_.Reset();
1050   data_->CancelPendingRead();
1051   return OK;
1052 }
1053 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1054 int MockTCPClientSocket::Write(
1055     IOBuffer* buf,
1056     int buf_len,
1057     CompletionOnceCallback callback,
1058     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1059   DCHECK(buf);
1060   DCHECK_GT(buf_len, 0);
1061 
1062   if (!connected_ || !data_)
1063     return ERR_UNEXPECTED;
1064 
1065   std::string data(buf->data(), buf_len);
1066   MockWriteResult write_result = data_->OnWrite(data);
1067 
1068   was_used_to_convey_data_ = true;
1069 
1070   if (write_result.result == ERR_CONNECTION_CLOSED) {
1071     // This MockWrite is just a marker to instruct us to set
1072     // peer_closed_connection_.
1073     peer_closed_connection_ = true;
1074   }
1075   // ERR_IO_PENDING is a signal that the socket data will call back
1076   // asynchronously later.
1077   if (write_result.result == ERR_IO_PENDING) {
1078     pending_write_callback_ = std::move(callback);
1079     return ERR_IO_PENDING;
1080   }
1081 
1082   if (write_result.mode == ASYNC) {
1083     RunCallbackAsync(std::move(callback), write_result.result);
1084     return ERR_IO_PENDING;
1085   }
1086 
1087   return write_result.result;
1088 }
1089 
SetReceiveBufferSize(int32_t size)1090 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1091   if (!connected_)
1092     return net::ERR_UNEXPECTED;
1093   data_->set_receive_buffer_size(size);
1094   return data_->set_receive_buffer_size_result();
1095 }
1096 
SetSendBufferSize(int32_t size)1097 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1098   if (!connected_)
1099     return net::ERR_UNEXPECTED;
1100   data_->set_send_buffer_size(size);
1101   return data_->set_send_buffer_size_result();
1102 }
1103 
SetNoDelay(bool no_delay)1104 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1105   if (!connected_)
1106     return false;
1107   data_->set_no_delay(no_delay);
1108   return data_->set_no_delay_result();
1109 }
1110 
SetKeepAlive(bool enable,int delay)1111 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1112   if (!connected_)
1113     return false;
1114   data_->set_keep_alive(enable, delay);
1115   return data_->set_keep_alive_result();
1116 }
1117 
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1118 void MockTCPClientSocket::SetBeforeConnectCallback(
1119     const BeforeConnectCallback& before_connect_callback) {
1120   DCHECK(!before_connect_callback_);
1121   DCHECK(!connected_);
1122 
1123   before_connect_callback_ = before_connect_callback;
1124 }
1125 
Connect(CompletionOnceCallback callback)1126 int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
1127   if (!data_)
1128     return ERR_UNEXPECTED;
1129 
1130   if (connected_)
1131     return OK;
1132 
1133   // Setting socket options fails if not connected, so need to set this before
1134   // calling |before_connect_callback_|.
1135   connected_ = true;
1136 
1137   if (before_connect_callback_) {
1138     for (size_t index = 0; index < addresses_.size(); index++) {
1139       int result = before_connect_callback_.Run();
1140       if (data_->connect_data().first_attempt_fails && index == 0) {
1141         continue;
1142       }
1143       DCHECK_NE(result, ERR_IO_PENDING);
1144       if (result != net::OK) {
1145         connected_ = false;
1146         return result;
1147       }
1148       break;
1149     }
1150   }
1151 
1152   peer_closed_connection_ = false;
1153 
1154   int result = data_->connect_data().result;
1155   IoMode mode = data_->connect_data().mode;
1156   if (mode == SYNCHRONOUS)
1157     return result;
1158 
1159   DCHECK(callback);
1160 
1161   if (result == ERR_IO_PENDING)
1162     pending_connect_callback_ = std::move(callback);
1163   else
1164     RunCallbackAsync(std::move(callback), result);
1165   return ERR_IO_PENDING;
1166 }
1167 
Disconnect()1168 void MockTCPClientSocket::Disconnect() {
1169   MockClientSocket::Disconnect();
1170   pending_connect_callback_.Reset();
1171   pending_read_callback_.Reset();
1172 }
1173 
IsConnected() const1174 bool MockTCPClientSocket::IsConnected() const {
1175   if (!data_)
1176     return false;
1177   return connected_ && !peer_closed_connection_;
1178 }
1179 
IsConnectedAndIdle() const1180 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1181   if (!data_)
1182     return false;
1183   return IsConnected() && data_->IsIdle();
1184 }
1185 
GetPeerAddress(IPEndPoint * address) const1186 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1187   if (addresses_.empty())
1188     return MockClientSocket::GetPeerAddress(address);
1189 
1190   if (data_->connect_data().first_attempt_fails) {
1191     DCHECK_GE(addresses_.size(), 2U);
1192     *address = addresses_[1];
1193   } else {
1194     *address = addresses_[0];
1195   }
1196   return OK;
1197 }
1198 
WasEverUsed() const1199 bool MockTCPClientSocket::WasEverUsed() const {
1200   return was_used_to_convey_data_;
1201 }
1202 
GetSSLInfo(SSLInfo * ssl_info)1203 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1204   return false;
1205 }
1206 
OnReadComplete(const MockRead & data)1207 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1208   // If |data_| has been destroyed, safest to just do nothing.
1209   if (!data_)
1210     return;
1211 
1212   // There must be a read pending.
1213   DCHECK(pending_read_if_ready_callback_);
1214   // You can't complete a read with another ERR_IO_PENDING status code.
1215   DCHECK_NE(ERR_IO_PENDING, data.result);
1216   // Since we've been waiting for data, need_read_data_ should be true.
1217   DCHECK(need_read_data_);
1218 
1219   read_data_ = data;
1220   need_read_data_ = false;
1221 
1222   // The caller is simulating that this IO completes right now.  Don't
1223   // let CompleteRead() schedule a callback.
1224   read_data_.mode = SYNCHRONOUS;
1225   RunCallback(std::move(pending_read_if_ready_callback_),
1226               read_data_.result > 0 ? OK : read_data_.result);
1227 }
1228 
OnWriteComplete(int rv)1229 void MockTCPClientSocket::OnWriteComplete(int rv) {
1230   // If |data_| has been destroyed, safest to just do nothing.
1231   if (!data_)
1232     return;
1233 
1234   // There must be a read pending.
1235   DCHECK(!pending_write_callback_.is_null());
1236   RunCallback(std::move(pending_write_callback_), rv);
1237 }
1238 
OnConnectComplete(const MockConnect & data)1239 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1240   // If |data_| has been destroyed, safest to just do nothing.
1241   if (!data_)
1242     return;
1243 
1244   RunCallback(std::move(pending_connect_callback_), data.result);
1245 }
1246 
OnDataProviderDestroyed()1247 void MockTCPClientSocket::OnDataProviderDestroyed() {
1248   data_ = nullptr;
1249 }
1250 
RetryRead(int rv)1251 void MockTCPClientSocket::RetryRead(int rv) {
1252   DCHECK(pending_read_callback_);
1253   DCHECK(pending_read_buf_.get());
1254   DCHECK_LT(0, pending_read_buf_len_);
1255 
1256   if (rv == OK) {
1257     rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1258                          base::BindOnce(&MockTCPClientSocket::RetryRead,
1259                                         base::Unretained(this)));
1260     if (rv == ERR_IO_PENDING)
1261       return;
1262   }
1263   pending_read_buf_ = nullptr;
1264   pending_read_buf_len_ = 0;
1265   RunCallback(std::move(pending_read_callback_), rv);
1266 }
1267 
ReadIfReadyImpl(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1268 int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
1269                                          int buf_len,
1270                                          CompletionOnceCallback callback) {
1271   if (!connected_ || !data_)
1272     return ERR_UNEXPECTED;
1273 
1274   DCHECK(!pending_read_if_ready_callback_);
1275 
1276   if (need_read_data_) {
1277     read_data_ = data_->OnRead();
1278     if (read_data_.result == ERR_CONNECTION_CLOSED) {
1279       // This MockRead is just a marker to instruct us to set
1280       // peer_closed_connection_.
1281       peer_closed_connection_ = true;
1282     }
1283     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1284       // This MockRead is just a marker to instruct us to set
1285       // peer_closed_connection_.  Skip it and get the next one.
1286       read_data_ = data_->OnRead();
1287       peer_closed_connection_ = true;
1288     }
1289     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1290     // to complete the async IO manually later (via OnReadComplete).
1291     if (read_data_.result == ERR_IO_PENDING) {
1292       // We need to be using async IO in this case.
1293       DCHECK(!callback.is_null());
1294       pending_read_if_ready_callback_ = std::move(callback);
1295       return ERR_IO_PENDING;
1296     }
1297     need_read_data_ = false;
1298   }
1299 
1300   int result = read_data_.result;
1301   DCHECK_NE(ERR_IO_PENDING, result);
1302   if (read_data_.mode == ASYNC) {
1303     DCHECK(!callback.is_null());
1304     read_data_.mode = SYNCHRONOUS;
1305     pending_read_if_ready_callback_ = std::move(callback);
1306     // base::Unretained() is safe here because RunCallbackAsync will wrap it
1307     // with a callback associated with a weak ptr.
1308     RunCallbackAsync(
1309         base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
1310                        base::Unretained(this)),
1311         result);
1312     return ERR_IO_PENDING;
1313   }
1314 
1315   was_used_to_convey_data_ = true;
1316   if (read_data_.data) {
1317     if (read_data_.data_len - read_offset_ > 0) {
1318       result = std::min(buf_len, read_data_.data_len - read_offset_);
1319       memcpy(buf->data(), read_data_.data + read_offset_, result);
1320       read_offset_ += result;
1321       if (read_offset_ == read_data_.data_len) {
1322         need_read_data_ = true;
1323         read_offset_ = 0;
1324       }
1325     } else {
1326       result = 0;  // EOF
1327     }
1328   }
1329   return result;
1330 }
1331 
RunReadIfReadyCallback(int result)1332 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1333   // If ReadIfReady is already canceled, do nothing.
1334   if (!pending_read_if_ready_callback_)
1335     return;
1336   std::move(pending_read_if_ready_callback_).Run(result);
1337 }
1338 
1339 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1340 void MockSSLClientSocket::ConnectCallback(
1341     MockSSLClientSocket* ssl_client_socket,
1342     CompletionOnceCallback callback,
1343     int rv) {
1344   if (rv == OK)
1345     ssl_client_socket->connected_ = true;
1346   std::move(callback).Run(rv);
1347 }
1348 
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1349 MockSSLClientSocket::MockSSLClientSocket(
1350     std::unique_ptr<StreamSocket> stream_socket,
1351     const HostPortPair& host_and_port,
1352     const SSLConfig& ssl_config,
1353     SSLSocketDataProvider* data)
1354     : net_log_(stream_socket->NetLog()),
1355       stream_socket_(std::move(stream_socket)),
1356       data_(data) {
1357   DCHECK(data_);
1358   peer_addr_ = data->connect.peer_addr;
1359 }
1360 
~MockSSLClientSocket()1361 MockSSLClientSocket::~MockSSLClientSocket() {
1362   Disconnect();
1363 }
1364 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1365 int MockSSLClientSocket::Read(IOBuffer* buf,
1366                               int buf_len,
1367                               CompletionOnceCallback callback) {
1368   return stream_socket_->Read(buf, buf_len, std::move(callback));
1369 }
1370 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1371 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1372                                      int buf_len,
1373                                      CompletionOnceCallback callback) {
1374   return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1375 }
1376 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1377 int MockSSLClientSocket::Write(
1378     IOBuffer* buf,
1379     int buf_len,
1380     CompletionOnceCallback callback,
1381     const NetworkTrafficAnnotationTag& traffic_annotation) {
1382   if (!data_->is_confirm_data_consumed)
1383     data_->write_called_before_confirm = true;
1384   return stream_socket_->Write(buf, buf_len, std::move(callback),
1385                                traffic_annotation);
1386 }
1387 
CancelReadIfReady()1388 int MockSSLClientSocket::CancelReadIfReady() {
1389   return stream_socket_->CancelReadIfReady();
1390 }
1391 
Connect(CompletionOnceCallback callback)1392 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1393   DCHECK(stream_socket_->IsConnected());
1394   data_->is_connect_data_consumed = true;
1395   if (data_->connect.result == OK)
1396     connected_ = true;
1397   RunClosureIfNonNull(std::move(data_->connect_callback));
1398   if (data_->connect.mode == ASYNC) {
1399     RunCallbackAsync(std::move(callback), data_->connect.result);
1400     return ERR_IO_PENDING;
1401   }
1402   return data_->connect.result;
1403 }
1404 
Disconnect()1405 void MockSSLClientSocket::Disconnect() {
1406   if (stream_socket_ != nullptr)
1407     stream_socket_->Disconnect();
1408 }
1409 
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1410 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1411     CompletionOnceCallback callback,
1412     int result) {
1413   DCHECK(in_confirm_handshake_);
1414   in_confirm_handshake_ = false;
1415   data_->is_confirm_data_consumed = true;
1416   std::move(callback).Run(result);
1417 }
1418 
ConfirmHandshake(CompletionOnceCallback callback)1419 int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
1420   DCHECK(stream_socket_->IsConnected());
1421   DCHECK(!in_confirm_handshake_);
1422   if (data_->is_confirm_data_consumed)
1423     return data_->confirm.result;
1424   RunClosureIfNonNull(std::move(data_->confirm_callback));
1425   if (data_->confirm.mode == ASYNC) {
1426     in_confirm_handshake_ = true;
1427     RunCallbackAsync(
1428         base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
1429                        base::Unretained(this), std::move(callback)),
1430         data_->confirm.result);
1431     return ERR_IO_PENDING;
1432   }
1433   data_->is_confirm_data_consumed = true;
1434   if (data_->confirm.result == ERR_IO_PENDING) {
1435     // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()`
1436     // never completes.
1437     in_confirm_handshake_ = true;
1438   }
1439   return data_->confirm.result;
1440 }
1441 
IsConnected() const1442 bool MockSSLClientSocket::IsConnected() const {
1443   return stream_socket_->IsConnected();
1444 }
1445 
IsConnectedAndIdle() const1446 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1447   return stream_socket_->IsConnectedAndIdle();
1448 }
1449 
WasEverUsed() const1450 bool MockSSLClientSocket::WasEverUsed() const {
1451   return stream_socket_->WasEverUsed();
1452 }
1453 
GetLocalAddress(IPEndPoint * address) const1454 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1455   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1456   return OK;
1457 }
1458 
GetPeerAddress(IPEndPoint * address) const1459 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1460   return stream_socket_->GetPeerAddress(address);
1461 }
1462 
GetNegotiatedProtocol() const1463 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1464   return data_->next_proto;
1465 }
1466 
1467 std::optional<std::string_view>
GetPeerApplicationSettings() const1468 MockSSLClientSocket::GetPeerApplicationSettings() const {
1469   return data_->peer_application_settings;
1470 }
1471 
GetSSLInfo(SSLInfo * requested_ssl_info)1472 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1473   *requested_ssl_info = data_->ssl_info;
1474   return true;
1475 }
1476 
ApplySocketTag(const SocketTag & tag)1477 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1478   return stream_socket_->ApplySocketTag(tag);
1479 }
1480 
NetLog() const1481 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1482   return net_log_;
1483 }
1484 
GetTotalReceivedBytes() const1485 int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
1486   NOTIMPLEMENTED();
1487   return 0;
1488 }
1489 
GetTotalReceivedBytes() const1490 int64_t MockClientSocket::GetTotalReceivedBytes() const {
1491   NOTIMPLEMENTED();
1492   return 0;
1493 }
1494 
SetReceiveBufferSize(int32_t size)1495 int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
1496   return OK;
1497 }
1498 
SetSendBufferSize(int32_t size)1499 int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
1500   return OK;
1501 }
1502 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1503 void MockSSLClientSocket::GetSSLCertRequestInfo(
1504     SSLCertRequestInfo* cert_request_info) const {
1505   DCHECK(cert_request_info);
1506   if (data_->cert_request_info) {
1507     cert_request_info->host_and_port = data_->cert_request_info->host_and_port;
1508     cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1509     cert_request_info->cert_authorities =
1510         data_->cert_request_info->cert_authorities;
1511     cert_request_info->signature_algorithms =
1512         data_->cert_request_info->signature_algorithms;
1513   } else {
1514     cert_request_info->Reset();
1515   }
1516 }
1517 
ExportKeyingMaterial(std::string_view label,bool has_context,std::string_view context,unsigned char * out,unsigned int outlen)1518 int MockSSLClientSocket::ExportKeyingMaterial(std::string_view label,
1519                                               bool has_context,
1520                                               std::string_view context,
1521                                               unsigned char* out,
1522                                               unsigned int outlen) {
1523   memset(out, 'A', outlen);
1524   return OK;
1525 }
1526 
GetECHRetryConfigs()1527 std::vector<uint8_t> MockSSLClientSocket::GetECHRetryConfigs() {
1528   return data_->ech_retry_configs;
1529 }
1530 
RunCallbackAsync(CompletionOnceCallback callback,int result)1531 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1532                                            int result) {
1533   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1534       FROM_HERE,
1535       base::BindOnce(&MockSSLClientSocket::RunCallback,
1536                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1537 }
1538 
RunCallback(CompletionOnceCallback callback,int result)1539 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1540                                       int result) {
1541   std::move(callback).Run(result);
1542 }
1543 
OnReadComplete(const MockRead & data)1544 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1545   NOTIMPLEMENTED();
1546 }
1547 
OnWriteComplete(int rv)1548 void MockSSLClientSocket::OnWriteComplete(int rv) {
1549   NOTIMPLEMENTED();
1550 }
1551 
OnConnectComplete(const MockConnect & data)1552 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1553   NOTIMPLEMENTED();
1554 }
1555 
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1556 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1557                                          net::NetLog* net_log)
1558     : data_(data),
1559       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1560       source_host_(IPAddress(192, 0, 2, 33)),
1561       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) {
1562   if (data_) {
1563     data_->Initialize(this);
1564     peer_addr_ = data->connect_data().peer_addr;
1565   }
1566 }
1567 
~MockUDPClientSocket()1568 MockUDPClientSocket::~MockUDPClientSocket() {
1569   if (data_)
1570     data_->DetachSocket();
1571 }
1572 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1573 int MockUDPClientSocket::Read(IOBuffer* buf,
1574                               int buf_len,
1575                               CompletionOnceCallback callback) {
1576   DCHECK(callback);
1577 
1578   if (!connected_ || !data_)
1579     return ERR_UNEXPECTED;
1580   data_transferred_ = true;
1581 
1582   // If the buffer is already in use, a read is already in progress!
1583   DCHECK(!pending_read_buf_);
1584 
1585   // Store our async IO data.
1586   pending_read_buf_ = buf;
1587   pending_read_buf_len_ = buf_len;
1588   pending_read_callback_ = std::move(callback);
1589 
1590   if (need_read_data_) {
1591     read_data_ = data_->OnRead();
1592     last_tos_ = read_data_.tos;
1593     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1594     // to complete the async IO manually later (via OnReadComplete).
1595     if (read_data_.result == ERR_IO_PENDING) {
1596       // We need to be using async IO in this case.
1597       DCHECK(!pending_read_callback_.is_null());
1598       return ERR_IO_PENDING;
1599     }
1600     need_read_data_ = false;
1601   }
1602 
1603   return CompleteRead();
1604 }
1605 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1606 int MockUDPClientSocket::Write(
1607     IOBuffer* buf,
1608     int buf_len,
1609     CompletionOnceCallback callback,
1610     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1611   DCHECK(buf);
1612   DCHECK_GT(buf_len, 0);
1613   DCHECK(callback);
1614 
1615   if (!connected_ || !data_)
1616     return ERR_UNEXPECTED;
1617   data_transferred_ = true;
1618 
1619   std::string data(buf->data(), buf_len);
1620   MockWriteResult write_result = data_->OnWrite(data);
1621 
1622   // ERR_IO_PENDING is a signal that the socket data will call back
1623   // asynchronously.
1624   if (write_result.result == ERR_IO_PENDING) {
1625     pending_write_callback_ = std::move(callback);
1626     return ERR_IO_PENDING;
1627   }
1628   if (write_result.mode == ASYNC) {
1629     RunCallbackAsync(std::move(callback), write_result.result);
1630     return ERR_IO_PENDING;
1631   }
1632   return write_result.result;
1633 }
1634 
SetReceiveBufferSize(int32_t size)1635 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1636   return OK;
1637 }
1638 
SetSendBufferSize(int32_t size)1639 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1640   return OK;
1641 }
1642 
SetDoNotFragment()1643 int MockUDPClientSocket::SetDoNotFragment() {
1644   return OK;
1645 }
1646 
SetRecvTos()1647 int MockUDPClientSocket::SetRecvTos() {
1648   return OK;
1649 }
1650 
SetTos(DiffServCodePoint dscp,EcnCodePoint ecn)1651 int MockUDPClientSocket::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) {
1652   return OK;
1653 }
1654 
Close()1655 void MockUDPClientSocket::Close() {
1656   connected_ = false;
1657 }
1658 
GetPeerAddress(IPEndPoint * address) const1659 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1660   if (!data_)
1661     return ERR_UNEXPECTED;
1662 
1663   *address = peer_addr_;
1664   return OK;
1665 }
1666 
GetLocalAddress(IPEndPoint * address) const1667 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1668   *address = IPEndPoint(source_host_, source_port_);
1669   return OK;
1670 }
1671 
UseNonBlockingIO()1672 void MockUDPClientSocket::UseNonBlockingIO() {}
1673 
SetMulticastInterface(uint32_t interface_index)1674 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1675   return OK;
1676 }
1677 
NetLog() const1678 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1679   return net_log_;
1680 }
1681 
Connect(const IPEndPoint & address)1682 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1683   if (!data_)
1684     return ERR_UNEXPECTED;
1685   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1686   connected_ = true;
1687   peer_addr_ = address;
1688   return data_->connect_data().result;
1689 }
1690 
ConnectUsingNetwork(handles::NetworkHandle network,const IPEndPoint & address)1691 int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
1692                                              const IPEndPoint& address) {
1693   DCHECK(!connected_);
1694   if (!data_)
1695     return ERR_UNEXPECTED;
1696   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1697   network_ = network;
1698   connected_ = true;
1699   peer_addr_ = address;
1700   return data_->connect_data().result;
1701 }
1702 
ConnectUsingDefaultNetwork(const IPEndPoint & address)1703 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1704   DCHECK(!connected_);
1705   if (!data_)
1706     return ERR_UNEXPECTED;
1707   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1708   network_ = kDefaultNetworkForTests;
1709   connected_ = true;
1710   peer_addr_ = address;
1711   return data_->connect_data().result;
1712 }
1713 
ConnectAsync(const IPEndPoint & address,CompletionOnceCallback callback)1714 int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address,
1715                                       CompletionOnceCallback callback) {
1716   DCHECK(callback);
1717   if (!data_) {
1718     return ERR_UNEXPECTED;
1719   }
1720   connected_ = true;
1721   peer_addr_ = address;
1722   int result = data_->connect_data().result;
1723   IoMode mode = data_->connect_data().mode;
1724   if (mode == SYNCHRONOUS) {
1725     return result;
1726   }
1727   RunCallbackAsync(std::move(callback), result);
1728   return ERR_IO_PENDING;
1729 }
1730 
ConnectUsingNetworkAsync(handles::NetworkHandle network,const IPEndPoint & address,CompletionOnceCallback callback)1731 int MockUDPClientSocket::ConnectUsingNetworkAsync(
1732     handles::NetworkHandle network,
1733     const IPEndPoint& address,
1734     CompletionOnceCallback callback) {
1735   DCHECK(callback);
1736   DCHECK(!connected_);
1737   if (!data_)
1738     return ERR_UNEXPECTED;
1739   network_ = network;
1740   connected_ = true;
1741   peer_addr_ = address;
1742   int result = data_->connect_data().result;
1743   IoMode mode = data_->connect_data().mode;
1744   if (mode == SYNCHRONOUS) {
1745     return result;
1746   }
1747   RunCallbackAsync(std::move(callback), result);
1748   return ERR_IO_PENDING;
1749 }
1750 
ConnectUsingDefaultNetworkAsync(const IPEndPoint & address,CompletionOnceCallback callback)1751 int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync(
1752     const IPEndPoint& address,
1753     CompletionOnceCallback callback) {
1754   DCHECK(!connected_);
1755   if (!data_)
1756     return ERR_UNEXPECTED;
1757   network_ = kDefaultNetworkForTests;
1758   connected_ = true;
1759   peer_addr_ = address;
1760   int result = data_->connect_data().result;
1761   IoMode mode = data_->connect_data().mode;
1762   if (mode == SYNCHRONOUS) {
1763     return result;
1764   }
1765   RunCallbackAsync(std::move(callback), result);
1766   return ERR_IO_PENDING;
1767 }
1768 
GetBoundNetwork() const1769 handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const {
1770   return network_;
1771 }
1772 
ApplySocketTag(const SocketTag & tag)1773 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1774   tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1775   tag_ = tag;
1776 }
1777 
GetLastTos() const1778 DscpAndEcn MockUDPClientSocket::GetLastTos() const {
1779   return TosToDscpAndEcn(last_tos_);
1780 }
1781 
OnReadComplete(const MockRead & data)1782 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1783   if (!data_)
1784     return;
1785 
1786   // There must be a read pending.
1787   DCHECK(pending_read_buf_.get());
1788   DCHECK(pending_read_callback_);
1789   // You can't complete a read with another ERR_IO_PENDING status code.
1790   DCHECK_NE(ERR_IO_PENDING, data.result);
1791   // Since we've been waiting for data, need_read_data_ should be true.
1792   DCHECK(need_read_data_);
1793 
1794   read_data_ = data;
1795   last_tos_ = data.tos;
1796   need_read_data_ = false;
1797 
1798   // The caller is simulating that this IO completes right now.  Don't
1799   // let CompleteRead() schedule a callback.
1800   read_data_.mode = SYNCHRONOUS;
1801 
1802   CompletionOnceCallback callback = std::move(pending_read_callback_);
1803   int rv = CompleteRead();
1804   RunCallback(std::move(callback), rv);
1805 }
1806 
OnWriteComplete(int rv)1807 void MockUDPClientSocket::OnWriteComplete(int rv) {
1808   if (!data_)
1809     return;
1810 
1811   // There must be a read pending.
1812   DCHECK(!pending_write_callback_.is_null());
1813   RunCallback(std::move(pending_write_callback_), rv);
1814 }
1815 
OnConnectComplete(const MockConnect & data)1816 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1817   NOTIMPLEMENTED();
1818 }
1819 
OnDataProviderDestroyed()1820 void MockUDPClientSocket::OnDataProviderDestroyed() {
1821   data_ = nullptr;
1822 }
1823 
CompleteRead()1824 int MockUDPClientSocket::CompleteRead() {
1825   DCHECK(pending_read_buf_.get());
1826   DCHECK(pending_read_buf_len_ > 0);
1827 
1828   // Save the pending async IO data and reset our |pending_| state.
1829   scoped_refptr<IOBuffer> buf = pending_read_buf_;
1830   int buf_len = pending_read_buf_len_;
1831   CompletionOnceCallback callback = std::move(pending_read_callback_);
1832   pending_read_buf_ = nullptr;
1833   pending_read_buf_len_ = 0;
1834 
1835   int result = read_data_.result;
1836   DCHECK(result != ERR_IO_PENDING);
1837 
1838   if (read_data_.data) {
1839     if (read_data_.data_len - read_offset_ > 0) {
1840       result = std::min(buf_len, read_data_.data_len - read_offset_);
1841       memcpy(buf->data(), read_data_.data + read_offset_, result);
1842       read_offset_ += result;
1843       if (read_offset_ == read_data_.data_len) {
1844         need_read_data_ = true;
1845         read_offset_ = 0;
1846       }
1847     } else {
1848       result = 0;  // EOF
1849     }
1850   }
1851 
1852   if (read_data_.mode == ASYNC) {
1853     DCHECK(!callback.is_null());
1854     RunCallbackAsync(std::move(callback), result);
1855     return ERR_IO_PENDING;
1856   }
1857   return result;
1858 }
1859 
RunCallbackAsync(CompletionOnceCallback callback,int result)1860 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1861                                            int result) {
1862   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1863       FROM_HERE,
1864       base::BindOnce(&MockUDPClientSocket::RunCallback,
1865                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1866 }
1867 
RunCallback(CompletionOnceCallback callback,int result)1868 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
1869                                       int result) {
1870   std::move(callback).Run(result);
1871 }
1872 
TestSocketRequest(std::vector<raw_ptr<TestSocketRequest,VectorExperimental>> * request_order,size_t * completion_count)1873 TestSocketRequest::TestSocketRequest(
1874     std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>* request_order,
1875     size_t* completion_count)
1876     : request_order_(request_order), completion_count_(completion_count) {
1877   DCHECK(request_order);
1878   DCHECK(completion_count);
1879 }
1880 
1881 TestSocketRequest::~TestSocketRequest() = default;
1882 
OnComplete(int result)1883 void TestSocketRequest::OnComplete(int result) {
1884   SetResult(result);
1885   (*completion_count_)++;
1886   request_order_->push_back(this);
1887 }
1888 
1889 // static
1890 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1891 
1892 // static
1893 const int ClientSocketPoolTest::kRequestNotFound = -2;
1894 
1895 ClientSocketPoolTest::ClientSocketPoolTest() = default;
1896 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
1897 
GetOrderOfRequest(size_t index) const1898 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1899   index--;
1900   if (index >= requests_.size())
1901     return kIndexOutOfBounds;
1902 
1903   for (size_t i = 0; i < request_order_.size(); i++)
1904     if (requests_[index].get() == request_order_[i])
1905       return i + 1;
1906 
1907   return kRequestNotFound;
1908 }
1909 
ReleaseOneConnection(KeepAlive keep_alive)1910 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1911   for (std::unique_ptr<TestSocketRequest>& it : requests_) {
1912     if (it->handle()->is_initialized()) {
1913       if (keep_alive == NO_KEEP_ALIVE)
1914         it->handle()->socket()->Disconnect();
1915       it->handle()->Reset();
1916       base::RunLoop().RunUntilIdle();
1917       return true;
1918     }
1919   }
1920   return false;
1921 }
1922 
ReleaseAllConnections(KeepAlive keep_alive)1923 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1924   bool released_one;
1925   do {
1926     released_one = ReleaseOneConnection(keep_alive);
1927   } while (released_one);
1928 }
1929 
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)1930 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1931     std::unique_ptr<StreamSocket> socket,
1932     ClientSocketHandle* handle,
1933     const SocketTag& socket_tag,
1934     CompletionOnceCallback callback,
1935     RequestPriority priority)
1936     : socket_(std::move(socket)),
1937       handle_(handle),
1938       socket_tag_(socket_tag),
1939       user_callback_(std::move(callback)),
1940       priority_(priority) {}
1941 
1942 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
1943 
Connect()1944 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1945   socket_->ApplySocketTag(socket_tag_);
1946   int rv = socket_->Connect(
1947       base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
1948   if (rv != ERR_IO_PENDING) {
1949     user_callback_.Reset();
1950     OnConnect(rv);
1951   }
1952   return rv;
1953 }
1954 
CancelHandle(const ClientSocketHandle * handle)1955 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1956     const ClientSocketHandle* handle) {
1957   if (handle != handle_)
1958     return false;
1959   socket_.reset();
1960   handle_ = nullptr;
1961   user_callback_.Reset();
1962   return true;
1963 }
1964 
OnConnect(int rv)1965 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1966   if (!socket_.get())
1967     return;
1968   if (rv == OK) {
1969     handle_->SetSocket(std::move(socket_));
1970 
1971     // Needed for socket pool tests that layer other sockets on top of mock
1972     // sockets.
1973     LoadTimingInfo::ConnectTiming connect_timing;
1974     base::TimeTicks now = base::TimeTicks::Now();
1975     connect_timing.domain_lookup_start = now;
1976     connect_timing.domain_lookup_end = now;
1977     connect_timing.connect_start = now;
1978     connect_timing.connect_end = now;
1979     handle_->set_connect_timing(connect_timing);
1980   } else {
1981     socket_.reset();
1982 
1983     // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
1984     ConnectionAttempts attempts;
1985     attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
1986     handle_->set_connection_attempts(attempts);
1987   }
1988 
1989   handle_ = nullptr;
1990 
1991   if (!user_callback_.is_null()) {
1992     std::move(user_callback_).Run(rv);
1993   }
1994 }
1995 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)1996 MockTransportClientSocketPool::MockTransportClientSocketPool(
1997     int max_sockets,
1998     int max_sockets_per_group,
1999     const CommonConnectJobParams* common_connect_job_params)
2000     : TransportClientSocketPool(
2001           max_sockets,
2002           max_sockets_per_group,
2003           base::Seconds(10) /* unused_idle_socket_timeout */,
2004           ProxyChain::Direct(),
2005           false /* is_for_websockets */,
2006           common_connect_job_params),
2007       client_socket_factory_(common_connect_job_params->client_socket_factory) {
2008 }
2009 
2010 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
2011 
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const std::optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)2012 int MockTransportClientSocketPool::RequestSocket(
2013     const ClientSocketPool::GroupId& group_id,
2014     scoped_refptr<ClientSocketPool::SocketParams> socket_params,
2015     const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
2016     RequestPriority priority,
2017     const SocketTag& socket_tag,
2018     RespectLimits respect_limits,
2019     ClientSocketHandle* handle,
2020     CompletionOnceCallback callback,
2021     const ProxyAuthCallback& on_auth_callback,
2022     const NetLogWithSource& net_log) {
2023   last_request_priority_ = priority;
2024   std::unique_ptr<StreamSocket> socket =
2025       client_socket_factory_->CreateTransportClientSocket(
2026           AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource());
2027   auto job = std::make_unique<MockConnectJob>(
2028       std::move(socket), handle, socket_tag, std::move(callback), priority);
2029   auto* job_ptr = job.get();
2030   job_list_.push_back(std::move(job));
2031   handle->set_group_generation(1);
2032   return job_ptr->Connect();
2033 }
2034 
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)2035 void MockTransportClientSocketPool::SetPriority(
2036     const ClientSocketPool::GroupId& group_id,
2037     ClientSocketHandle* handle,
2038     RequestPriority priority) {
2039   for (auto& job : job_list_) {
2040     if (job->handle() == handle) {
2041       job->set_priority(priority);
2042       return;
2043     }
2044   }
2045   NOTREACHED();
2046 }
2047 
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)2048 void MockTransportClientSocketPool::CancelRequest(
2049     const ClientSocketPool::GroupId& group_id,
2050     ClientSocketHandle* handle,
2051     bool cancel_connect_job) {
2052   for (std::unique_ptr<MockConnectJob>& it : job_list_) {
2053     if (it->CancelHandle(handle)) {
2054       cancel_count_++;
2055       break;
2056     }
2057   }
2058 }
2059 
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)2060 void MockTransportClientSocketPool::ReleaseSocket(
2061     const ClientSocketPool::GroupId& group_id,
2062     std::unique_ptr<StreamSocket> socket,
2063     int64_t generation) {
2064   EXPECT_EQ(1, generation);
2065   release_count_++;
2066 }
2067 
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)2068 WrappedStreamSocket::WrappedStreamSocket(
2069     std::unique_ptr<StreamSocket> transport)
2070     : transport_(std::move(transport)) {}
2071 WrappedStreamSocket::~WrappedStreamSocket() = default;
2072 
Bind(const net::IPEndPoint & local_addr)2073 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
2074   NOTREACHED();
2075   return ERR_FAILED;
2076 }
2077 
Connect(CompletionOnceCallback callback)2078 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2079   return transport_->Connect(std::move(callback));
2080 }
2081 
Disconnect()2082 void WrappedStreamSocket::Disconnect() {
2083   transport_->Disconnect();
2084 }
2085 
IsConnected() const2086 bool WrappedStreamSocket::IsConnected() const {
2087   return transport_->IsConnected();
2088 }
2089 
IsConnectedAndIdle() const2090 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2091   return transport_->IsConnectedAndIdle();
2092 }
2093 
GetPeerAddress(IPEndPoint * address) const2094 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2095   return transport_->GetPeerAddress(address);
2096 }
2097 
GetLocalAddress(IPEndPoint * address) const2098 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2099   return transport_->GetLocalAddress(address);
2100 }
2101 
NetLog() const2102 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2103   return transport_->NetLog();
2104 }
2105 
WasEverUsed() const2106 bool WrappedStreamSocket::WasEverUsed() const {
2107   return transport_->WasEverUsed();
2108 }
2109 
GetNegotiatedProtocol() const2110 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2111   return transport_->GetNegotiatedProtocol();
2112 }
2113 
GetSSLInfo(SSLInfo * ssl_info)2114 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2115   return transport_->GetSSLInfo(ssl_info);
2116 }
2117 
GetTotalReceivedBytes() const2118 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2119   return transport_->GetTotalReceivedBytes();
2120 }
2121 
ApplySocketTag(const SocketTag & tag)2122 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2123   transport_->ApplySocketTag(tag);
2124 }
2125 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2126 int WrappedStreamSocket::Read(IOBuffer* buf,
2127                               int buf_len,
2128                               CompletionOnceCallback callback) {
2129   return transport_->Read(buf, buf_len, std::move(callback));
2130 }
2131 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2132 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2133                                      int buf_len,
2134                                      CompletionOnceCallback callback) {
2135   return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2136 }
2137 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2138 int WrappedStreamSocket::Write(
2139     IOBuffer* buf,
2140     int buf_len,
2141     CompletionOnceCallback callback,
2142     const NetworkTrafficAnnotationTag& traffic_annotation) {
2143   return transport_->Write(buf, buf_len, std::move(callback),
2144                            TRAFFIC_ANNOTATION_FOR_TESTS);
2145 }
2146 
SetReceiveBufferSize(int32_t size)2147 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2148   return transport_->SetReceiveBufferSize(size);
2149 }
2150 
SetSendBufferSize(int32_t size)2151 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2152   return transport_->SetSendBufferSize(size);
2153 }
2154 
Connect(CompletionOnceCallback callback)2155 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2156   connected_ = true;
2157   return WrappedStreamSocket::Connect(std::move(callback));
2158 }
2159 
ApplySocketTag(const SocketTag & tag)2160 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2161   tagged_before_connected_ &= !connected_ || tag == tag_;
2162   tag_ = tag;
2163   transport_->ApplySocketTag(tag);
2164 }
2165 
2166 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)2167 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2168     const AddressList& addresses,
2169     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2170     NetworkQualityEstimator* network_quality_estimator,
2171     NetLog* net_log,
2172     const NetLogSource& source) {
2173   auto socket = std::make_unique<MockTaggingStreamSocket>(
2174       MockClientSocketFactory::CreateTransportClientSocket(
2175           addresses, std::move(socket_performance_watcher),
2176           network_quality_estimator, net_log, source));
2177   tcp_socket_ = socket.get();
2178   return std::move(socket);
2179 }
2180 
2181 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2182 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2183     DatagramSocket::BindType bind_type,
2184     NetLog* net_log,
2185     const NetLogSource& source) {
2186   std::unique_ptr<DatagramClientSocket> socket(
2187       MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2188                                                           source));
2189   udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2190   return socket;
2191 }
2192 
// Canned SOCKS handshake bytes for SOCKS client socket tests.

const char kSOCKS4TestHost[] = "127.0.0.1";
const int kSOCKS4TestPort = 80;

// SOCKS4 CONNECT to 127.0.0.1:80 with an empty user ID.
const char kSOCKS4OkRequestLocalHostPort80[] = {
    0x04,                    // VN: protocol version 4.
    0x01,                    // CD: CONNECT.
    0x00, 0x50,              // DSTPORT: 80, network byte order.
    0x7F, 0x00, 0x00, 0x01,  // DSTIP: 127.0.0.1.
    0x00,                    // USERID: empty, NUL-terminated.
};
const int kSOCKS4OkRequestLocalHostPort80Length =
    std::size(kSOCKS4OkRequestLocalHostPort80);

// SOCKS4 reply: request granted (0x5A); port and address left zero.
const char kSOCKS4OkReply[] = {
    0x00, 0x5A, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
};
const int kSOCKS4OkReplyLength = std::size(kSOCKS4OkReply);

const char kSOCKS5TestHost[] = "host";
const int kSOCKS5TestPort = 80;

// SOCKS5 greeting: version 5, one method offered, "no authentication".
const char kSOCKS5GreetRequest[] = {0x05, 0x01, 0x00};
const int kSOCKS5GreetRequestLength = std::size(kSOCKS5GreetRequest);

// SOCKS5 method selection: version 5, "no authentication" chosen.
const char kSOCKS5GreetResponse[] = {0x05, 0x00};
const int kSOCKS5GreetResponseLength = std::size(kSOCKS5GreetResponse);

// SOCKS5 CONNECT to the domain name "host", port 80.
const char kSOCKS5OkRequest[] = {
    0x05, 0x01, 0x00,     // VER 5, CMD CONNECT, RSV.
    0x03, 0x04,           // ATYP domain name, 4 bytes long.
    'h',  'o',  's', 't', // Domain name.
    0x00, 0x50,           // Port 80, network byte order.
};
const int kSOCKS5OkRequestLength = std::size(kSOCKS5OkRequest);

// SOCKS5 reply: succeeded, bound to IPv4 127.0.0.1 port 80.
const char kSOCKS5OkResponse[] = {
    0x05, 0x00, 0x00,        // VER 5, REP succeeded, RSV.
    0x01,                    // ATYP IPv4.
    0x7F, 0x00, 0x00, 0x01,  // 127.0.0.1.
    0x00, 0x50,              // Port 80, network byte order.
};
const int kSOCKS5OkResponseLength = std::size(kSOCKS5OkResponse);
2220 
CountReadBytes(base::span<const MockRead> reads)2221 int64_t CountReadBytes(base::span<const MockRead> reads) {
2222   int64_t total = 0;
2223   for (const MockRead& read : reads)
2224     total += read.data_len;
2225   return total;
2226 }
2227 
CountWriteBytes(base::span<const MockWrite> writes)2228 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2229   int64_t total = 0;
2230   for (const MockWrite& write : writes)
2231     total += write.data_len;
2232   return total;
2233 }
2234 
2235 #if BUILDFLAG(IS_ANDROID)
CanGetTaggedBytes()2236 bool CanGetTaggedBytes() {
2237   // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2238   // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2239   // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2240   //
2241   // To read traffic statistics from netd, apps should use the API
2242   // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2243   // statistics for local traffic, only mobile and WiFi traffic, so it would not
2244   // work in tests that spin up a local server. So for now, GetTaggedBytes is
2245   // only supported on Android releases older than P.
2246   return base::android::BuildInfo::GetInstance()->sdk_int() <
2247          base::android::SDK_VERSION_P;
2248 }
2249 
GetTaggedBytes(int32_t expected_tag)2250 uint64_t GetTaggedBytes(int32_t expected_tag) {
2251   EXPECT_TRUE(CanGetTaggedBytes());
2252 
2253   // To determine how many bytes the system saw with a particular tag read
2254   // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2255   // dump of all the UIDs and their tags sent and received bytes.
2256   uint64_t bytes = 0;
2257   std::string contents;
2258   EXPECT_TRUE(base::ReadFileToString(
2259       base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2260   for (size_t i = contents.find('\n');  // Skip first line which is headers.
2261        i != std::string::npos && i < contents.length();) {
2262     uint64_t tag, rx_bytes;
2263     uid_t uid;
2264     int n;
2265     // Parse out the numbers we care about. For reference here's the column
2266     // headers:
2267     // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2268     // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2269     // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2270     // tx_udp_packets tx_other_bytes tx_other_packets
2271     EXPECT_EQ(sscanf(contents.c_str() + i,
2272                      "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2273                      " %*d %*d %*d %*d %*d %*d %*d %*d "
2274                      "%*d %*d %*d %*d %*d %*d %*d%n",
2275                      &tag, &uid, &rx_bytes, &n),
2276               3);
2277     // If this line matches our UID and |expected_tag| then add it to the total.
2278     if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2279       bytes += rx_bytes;
2280     }
2281     // Move |i| to the next line.
2282     i += n + 1;
2283   }
2284   return bytes;
2285 }
2286 #endif
2287 
2288 }  // namespace net
2289