// Copyright 2012 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// NOTE(review): in this copy of the file every angle-bracketed token appears
// to be missing: the system header names after the bare #include directives,
// the target types of static_cast, the arguments of std::unique_ptr /
// std::make_unique / base::span / scoped_refptr, and the template parameter
// list of DumpMockReadWrite. Restore them from the upstream file before
// building; the code below is otherwise reproduced unchanged.

#include "net/socket/socket_test_util.h"

#include "base/memory/raw_ptr.h"

#include  // For SCNx64
#include

#include
#include
#include
#include
#include
#include
#include

#include "base/compiler_specific.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/ranges/algorithm.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "build/build_config.h"
#include "net/base/address_family.h"
#include "net/base/address_list.h"
#include "net/base/auth.h"
#include "net/base/hex_utils.h"
#include "net/base/ip_address.h"
#include "net/base/load_timing_info.h"
#include "net/base/proxy_server.h"
#include "net/http/http_network_session.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/log/net_log_source.h"
#include "net/log/net_log_source_type.h"
#include "net/socket/connect_job.h"
#include "net/socket/socket.h"
#include "net/socket/stream_socket.h"
#include "net/socket/websocket_endpoint_lock_manager.h"
#include "net/ssl/ssl_cert_request_info.h"
#include "net/ssl/ssl_connection_status_flags.h"
#include "net/ssl/ssl_info.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/strings/ascii.h"

#if BUILDFLAG(IS_ANDROID)
#include "base/android/build_info.h"
#endif

// Starts a VLOG(level) message consisting of the prefix |s|, the name of the
// calling function and "() "; the caller streams the rest of the message.
#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "

namespace net {

namespace {

// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the high four
// bits of |x|.
inline char AsciifyHigh(char x) {
  char nybble = static_cast((x >> 4) & 0x0F);
  return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
}

// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the low four
// bits of |x|.
inline char AsciifyLow(char x) {
  char nybble = static_cast((x >> 0) & 0x0F);
  return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10);
}

// Returns |x| if absl::ascii_isprint() accepts it, otherwise '.'.
inline char Asciify(char x) {
  return absl::ascii_isprint(static_cast(x)) ? x : '.';
}

// Logs |data_len| and the bytes of |data| at DVLOG(1). Bytes are emitted four
// per log message as hex digits followed by their printable form. Returns
// without logging when the minimum log level is above LOGGING_INFO.
void DumpData(const char* data, int data_len) {
  if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
    return;
  }
  DVLOG(1) << "Length: " << data_len;
  // Prefix of the first data message; replaced after the first full group.
  const char* pfx = "Data: ";
  if (!data || (data_len <= 0)) {
    DVLOG(1) << pfx << "";
  } else {
    // |i| is declared outside the loop because the switch below needs the
    // index of the first byte that was not part of a complete group of four.
    int i;
    for (i = 0; i <= (data_len - 4); i += 4) {
      DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
               << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
               << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
               << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << " '"
               << Asciify(data[i + 0]) << Asciify(data[i + 1])
               << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
      pfx = " ";
    }
    // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
    switch (data_len - i) {
      case 3:
        DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
                 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
                 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
                 << " '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
                 << Asciify(data[i + 2]) << " '";
        break;
      case 2:
        DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
                 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
                 << " '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
                 << " '";
        break;
      case 1:
        DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
                 << " '" << Asciify(data[i + 0]) << " '";
        break;
    }
  }
}

// Logs the mode, result, data and sequence number of |r| at DVLOG(1). The
// MockRead::STOPLOOP bit is masked out of the logged sequence number and
// reported as " (STOP)" instead. Returns without logging when the minimum log
// level is above LOGGING_INFO.
template void DumpMockReadWrite(const MockReadWrite& r) {
  if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
    return;
  }
  DVLOG(1) << "Async: " << (r.mode == ASYNC) << "\nResult: " << r.result;
  DumpData(r.data, r.data_len);
  const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
  DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
}

// Runs |closure| unless it is null.
void RunClosureIfNonNull(base::OnceClosure closure) {
  if (!closure.is_null()) {
    std::move(closure).Run();
  }
}

}  // namespace

// Defaults to an ASYNC connect that returns OK, with peer address
// 192.0.2.33:0.
MockConnect::MockConnect() : mode(ASYNC), result(OK) {
  peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
}

// Uses the given mode and result, with peer address 192.0.2.33:0.
MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
  peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
}

// Uses the given mode, result and peer address.
MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
    : mode(io_mode), result(r), peer_addr(addr) {}

// Uses the given mode, result and peer address, and records
// |first_attempt_fails|.
MockConnect::MockConnect(IoMode io_mode,
                         int r,
                         IPEndPoint addr,
                         bool first_attempt_fails)
    : mode(io_mode),
      result(r),
      peer_addr(addr),
      first_attempt_fails(first_attempt_fails) {}

MockConnect::~MockConnect() = default;

// Defaults to a SYNCHRONOUS confirm that returns OK.
MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}

MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}

MockConfirm::~MockConfirm() = default;

// Base implementation: always reports idle.
bool SocketDataProvider::IsIdle() const {
  return true;
}

// Attaches |socket| to this provider and calls Reset(). |socket| must be
// non-null and no socket may already be attached.
void SocketDataProvider::Initialize(AsyncSocket* socket) {
  CHECK(!socket_);
  CHECK(socket);
  socket_ = socket;
  Reset();
}

// Clears the attached socket. A socket must currently be attached.
void SocketDataProvider::DetachSocket() {
  CHECK(socket_);
  socket_ = nullptr;
}

SocketDataProvider::SocketDataProvider() = default;

// Tells a still-attached socket that this provider is going away.
SocketDataProvider::~SocketDataProvider() {
  if (socket_)
    socket_->OnDataProviderDestroyed();
}

// Stores the spans of expected reads and writes.
StaticSocketDataHelper::StaticSocketDataHelper(
    base::span reads,
    base::span writes)
    : reads_(reads), writes_(writes) {}

StaticSocketDataHelper::~StaticSocketDataHelper() = default;

// Returns the next unconsumed read without consuming it. CHECK-fails if all
// reads have been consumed.
const MockRead& StaticSocketDataHelper::PeekRead() const {
  CHECK(!AllReadDataConsumed());
  return reads_[read_index_];
}

// Returns the next unconsumed write without consuming it. CHECK-fails if all
// writes have been consumed.
const MockWrite& StaticSocketDataHelper::PeekWrite() const {
  CHECK(!AllWriteDataConsumed());
  return writes_[write_index_];
}

// Returns the next unconsumed read and advances past it. CHECK-fails if all
// reads have been consumed.
const MockRead& StaticSocketDataHelper::AdvanceRead() {
  CHECK(!AllReadDataConsumed());
  return reads_[read_index_++];
}

// Returns the next unconsumed write and advances past it. CHECK-fails if all
// writes have been consumed.
const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
  CHECK(!AllWriteDataConsumed());
  return writes_[write_index_++];
}

// Rewinds both the read and the write position to the start.
void StaticSocketDataHelper::Reset() {
  read_index_ = 0;
  write_index_ = 0;
}

// Compares |data| with the next expected non-pause write and reports a gtest
// failure on mismatch, including hex dumps and, when |printer| is non-null,
// its formatted output. Returns true when the expected write carries no data
// or when the first |data_len| bytes of |data| equal the expected data.
// CHECK-fails if all writes have been consumed.
bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
                                             SocketDataPrinter* printer) {
  CHECK(!AllWriteDataConsumed());
  // Check that the actual data matches the expectations, skipping over any
  // pause events.
  const MockWrite& next_write = PeekRealWrite();
  if (!next_write.data)
    return true;

  // Note: Partial writes are supported here.  If the expected data
  // is a match, but shorter than the write actually written, that is legal.
  // Example:
  //   Application writes "foobarbaz" (9 bytes)
  //   Expected write was "foo" (3 bytes)
  //   This is a success, and the function returns true.
  std::string expected_data(next_write.data, next_write.data_len);
  std::string actual_data(data.substr(0, next_write.data_len));
  if (printer) {
    EXPECT_TRUE(actual_data == expected_data)
        << "Actual formatted write data:\n"
        << printer->PrintWrite(data) << "Expected formatted write data:\n"
        << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
        << HexDump(data) << "Expected raw write data:\n"
        << HexDump(expected_data);
  } else {
    EXPECT_TRUE(actual_data == expected_data)
        << "Actual write data:\n"
        << HexDump(data) << "Expected write data:\n"
        << HexDump(expected_data);
  }
  return expected_data == actual_data;
}

// Reports a gtest failure if any reads remain unconsumed. The failure message
// lists each remaining read: its mode, sequence number, result when not OK,
// and its data (via |printer| when non-null, plus a hex dump).
void StaticSocketDataHelper::ExpectAllReadDataConsumed(
    SocketDataPrinter* printer) const {
  if (AllReadDataConsumed()) {
    return;
  }

  std::ostringstream msg;
  if (read_index_ < read_count()) {
    msg << "Unconsumed reads:\n";
    for (size_t i = read_index_; i < read_count(); i++) {
      msg << (reads_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockRead seq "
          << reads_[i].sequence_number << ":\n";
      if (reads_[i].result != OK) {
        msg << "Result: " << reads_[i].result << "\n";
      }
      if (reads_[i].data) {
        std::string data(reads_[i].data, reads_[i].data_len);
        if (printer) {
          // PrintWrite() is the formatting call used for read data as well.
          msg << printer->PrintWrite(data);
        }
        msg << HexDump(data);
      }
    }
  }
  EXPECT_TRUE(AllReadDataConsumed()) << msg.str();
}

// Reports a gtest failure if any writes remain unconsumed. The failure message
// lists each remaining write: its mode, sequence number, result when not OK,
// and its data (via |printer| when non-null, plus a hex dump).
void StaticSocketDataHelper::ExpectAllWriteDataConsumed(
    SocketDataPrinter* printer) const {
  if (AllWriteDataConsumed()) {
    return;
  }

  std::ostringstream msg;
  if (write_index_ < write_count()) {
    msg << "Unconsumed writes:\n";
    for (size_t i = write_index_; i < write_count(); i++) {
      msg << (writes_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockWrite seq "
          << writes_[i].sequence_number << ":\n";
      if (writes_[i].result != OK) {
        msg << "Result: " << writes_[i].result << "\n";
      }
      if (writes_[i].data) {
        std::string data(writes_[i].data, writes_[i].data_len);
        if (printer) {
          msg << printer->PrintWrite(data);
        }
        msg << HexDump(data);
      }
    }
  }
  EXPECT_TRUE(AllWriteDataConsumed()) << msg.str();
}

// Returns the first write at or after the current position that is not a
// pause, where a pause is an ASYNC write whose result is ERR_IO_PENDING.
// CHECK-fails when no such write exists.
const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
  for (size_t i = write_index_; i < write_count(); i++) {
    if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
      return writes_[i];
  }

  CHECK(false) << "No write data available.";
  return writes_[0];  // Avoid warning about unreachable missing return.
}

// Creates a provider with no reads and no writes.
StaticSocketDataProvider::StaticSocketDataProvider()
    : StaticSocketDataProvider(base::span(), base::span()) {}

// Creates a provider that serves |reads| and expects |writes|.
StaticSocketDataProvider::StaticSocketDataProvider(
    base::span reads,
    base::span writes)
    : helper_(reads, writes) {}

StaticSocketDataProvider::~StaticSocketDataProvider() = default;

// While paused, AllReadDataConsumed() reports true, so OnRead() hands out
// ERR_IO_PENDING instead of the next mock read.
void StaticSocketDataProvider::Pause() {
  paused_ = true;
}

// Undoes Pause().
void StaticSocketDataProvider::Resume() {
  paused_ = false;
}

// Returns the next mock read, or a SYNCHRONOUS ERR_IO_PENDING read when
// AllReadDataConsumed() is true (out of reads, or paused).
MockRead StaticSocketDataProvider::OnRead() {
  if (AllReadDataConsumed()) {
    const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
    return pending_read;
  }

  return helper_.AdvanceRead();
}

// Matches |data| against the next expected write and returns the mocked
// result. With no mock writes configured at all, every write succeeds
// synchronously with the full length. Running out of mock writes, or a data
// mismatch, yields a gtest failure and a synchronous ERR_UNEXPECTED.
MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
  if (helper_.write_count() == 0) {
    // Not using mock writes; succeed synchronously.
    return MockWriteResult(SYNCHRONOUS, data.length());
  }
  if (printer_) {
    EXPECT_FALSE(helper_.AllWriteDataConsumed())
        << "No more mock data to match write:\nFormatted write data:\n"
        << printer_->PrintWrite(data) << "Raw write data:\n"
        << HexDump(data);
  } else {
    EXPECT_FALSE(helper_.AllWriteDataConsumed())
        << "No more mock data to match write:\nRaw write data:\n"
        << HexDump(data);
  }
  if (helper_.AllWriteDataConsumed()) {
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
  }

  // Check that what we are writing matches the expectation.
  // Then give the mocked return value.
  if (!helper_.VerifyWriteData(data, printer_))
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);

  const MockWrite& next_write = helper_.AdvanceWrite();

  // In the case that the write was successful, return the number of bytes
  // written. Otherwise return the error code.
  int result =
      next_write.result == OK ? next_write.data_len : next_write.result;
  return MockWriteResult(next_write.mode, result);
}

// True when paused or when the helper has no reads left.
bool StaticSocketDataProvider::AllReadDataConsumed() const {
  return paused_ || helper_.AllReadDataConsumed();
}

bool StaticSocketDataProvider::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}

// Rewinds the helper's read and write positions. |paused_| is not touched.
void StaticSocketDataProvider::Reset() {
  helper_.Reset();
}

// Sets up the connect result and the default expected SSL version range, and
// records TLS 1.3 plus a cipher suite in |ssl_info.connection_status|.
SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
    : connect(mode, result),
      expected_ssl_version_min(kDefaultSSLVersionMin),
      expected_ssl_version_max(kDefaultSSLVersionMax) {
  SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
                                &ssl_info.connection_status);
  // 0x1301 is TLS_AES_128_GCM_SHA256 in the TLS 1.3 cipher suite list
  // (RFC 8446, Appendix B.4).
  // NOTE(review): this comment previously named TLS_CHACHA20_POLY1305_SHA256,
  // which is 0x1303 -- confirm which suite is intended.
  SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
}

SSLSocketDataProvider::SSLSocketDataProvider(
    const SSLSocketDataProvider& other) = default;

SSLSocketDataProvider::~SSLSocketDataProvider() = default;

// Creates sequenced data with no reads and no writes.
SequencedSocketData::SequencedSocketData()
    : SequencedSocketData(base::span(), base::span()) {}

// Creates sequenced data from |reads| and |writes| and CHECKs that their
// sequence numbers are well formed (see the comments in the body).
SequencedSocketData::SequencedSocketData(base::span reads,
                                         base::span writes)
    : helper_(reads, writes) {
  // Check that reads and writes have a contiguous set of sequence numbers
  // starting from 0 and working their way up, with no repeats and skipping
  // no values.
  int next_sequence_number = 0;
  bool last_event_was_pause = false;

  auto next_read = reads.begin();
  auto next_write = writes.begin();
  while (next_read != reads.end() || next_write != writes.end()) {
    if (next_read != reads.end() &&
        next_read->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause)
            << "Two pauses in a row are not allowed: " << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_read->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_read->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_read;
      ++next_sequence_number;
      continue;
    }
    if (next_write != writes.end() &&
        next_write->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause)
            << "Two pauses in a row are not allowed: " << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_write->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_write->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_write;
      ++next_sequence_number;
      continue;
    }
    // Neither the next read nor the next write carries the expected sequence
    // number.
    // NOTE(review): the else branch is also reached when the writes are
    // exhausted but the next read has an unexpected sequence number, in which
    // case "Too few writes" may be misleading.
    if (next_write != writes.end()) {
      CHECK(false) << "Sequence number " << next_write->sequence_number
                   << " not found where expected: " << next_sequence_number;
    } else {
      CHECK(false) << "Too few writes, next expected sequence number: "
                   << next_sequence_number;
    }
    return;
  }

  // Last event must not be a pause.  For the final event to indicate the
  // operation never completes, it should be SYNCHRONOUS and return
  // ERR_IO_PENDING.
  CHECK(!last_event_was_pause);

  CHECK(next_read == reads.end());
  CHECK(next_write == writes.end());
}

// Same as the two-argument constructor, and also stores |connect| as the
// connect data.
SequencedSocketData::SequencedSocketData(const MockConnect& connect,
                                         base::span reads,
                                         base::span writes)
    : SequencedSocketData(reads, writes) {
  set_connect_data(connect);
}

// Handles a read request. Requires the read state to be idle and read data to
// remain. If the next read is due at the current sequence number it is either
// returned directly (SYNCHRONOUS), turned into a pause (ASYNC with
// ERR_IO_PENDING), or completed by a posted task. If it is not yet due, an
// ASYNC read becomes pending and a SYNCHRONOUS read is a test failure.
MockRead SequencedSocketData::OnRead() {
  CHECK_EQ(IoState::kIdle, read_state_);
  CHECK(!helper_.AllReadDataConsumed())
      << "Application tried to read but there is no read data left";

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockRead& next_read = helper_.PeekRead();
  NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
  CHECK_GE(next_read.sequence_number, sequence_number_);

  if (next_read.sequence_number <= sequence_number_) {
    if (next_read.mode == SYNCHRONOUS) {
      NET_TRACE(1, " *** ") << "Returning synchronously";
      DumpMockReadWrite(next_read);
      helper_.AdvanceRead();
      ++sequence_number_;
      // Consuming this read may make a pending write due.
      MaybePostWriteCompleteTask();
      return next_read;
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_read.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
      read_state_ = IoState::kPaused;
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
    }
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(IoState::kCompleting, write_state_);
    read_state_ = IoState::kCompleting;
  } else if (next_read.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
    read_state_ = IoState::kPending;
  }

  return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
}

// Handles a write request. Requires the write state to be idle and write data
// to remain. After verifying |data|, a write that is due at the current
// sequence number is either completed directly (SYNCHRONOUS), turned into a
// pause (ASYNC with ERR_IO_PENDING), or completed by a posted task. If it is
// not yet due, an ASYNC write becomes pending and a SYNCHRONOUS write is a
// test failure.
MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
  CHECK_EQ(IoState::kIdle, write_state_);
  if (printer_) {
    CHECK(!helper_.AllWriteDataConsumed())
        << "\nNo more mock data to match write:\nFormatted write data:\n"
        << printer_->PrintWrite(data) << "Raw write data:\n"
        << HexDump(data);
  } else {
    CHECK(!helper_.AllWriteDataConsumed())
        << "\nNo more mock data to match write:\nRaw write data:\n"
        << HexDump(data);
  }

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockWrite& next_write = helper_.PeekWrite();
  NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
  CHECK_GE(next_write.sequence_number, sequence_number_);

  if (!helper_.VerifyWriteData(data, printer_))
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);

  if (next_write.sequence_number <= sequence_number_) {
    if (next_write.mode == SYNCHRONOUS) {
      helper_.AdvanceWrite();
      ++sequence_number_;
      // Consuming this write may make a pending read due.
      MaybePostReadCompleteTask();
      // In the case that the write was successful, return the number of bytes
      // written. Otherwise return the error code.
      int rv =
          next_write.result != OK ? next_write.result : next_write.data_len;
      NET_TRACE(1, " *** ") << "Returning synchronously";
      return MockWriteResult(SYNCHRONOUS, rv);
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_write.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
      write_state_ = IoState::kPaused;
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
    }

    NET_TRACE(1, " *** ") << "Posting task to complete write";
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(IoState::kCompleting, read_state_);
    write_state_ = IoState::kCompleting;
  } else if (next_write.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
    write_state_ = IoState::kPending;
  }

  return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
}

bool SequencedSocketData::AllReadDataConsumed() const {
  return helper_.AllReadDataConsumed();
}

// Returns a pending read to the idle state. The read must be pending.
void SequencedSocketData::CancelPendingRead() {
  DCHECK_EQ(IoState::kPending, read_state_);

  read_state_ = IoState::kIdle;
}

bool SequencedSocketData::AllWriteDataConsumed() const {
  return helper_.AllWriteDataConsumed();
}

// Reports a gtest failure describing any reads that were not consumed.
void SequencedSocketData::ExpectAllReadDataConsumed() const {
  helper_.ExpectAllReadDataConsumed(printer_.get());
}

// Reports a gtest failure describing any writes that were not consumed.
void SequencedSocketData::ExpectAllWriteDataConsumed() const {
  helper_.ExpectAllWriteDataConsumed(printer_.get());
}

// Returns false only when |busy_before_sync_reads_| is set and the next event
// in the sequence is a SYNCHRONOUS read.
bool SequencedSocketData::IsIdle() const {
  // If |busy_before_sync_reads_| is not set, always considered idle.  If
  // no reads left, or the next operation is a write, also consider it idle.
  if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
      helper_.PeekRead().sequence_number != sequence_number_) {
    return true;
  }

  // If the next operation is synchronous read, treat the socket as not idle.
  if (helper_.PeekRead().mode == SYNCHRONOUS)
    return false;
  return true;
}

// True when either the read side or the write side is paused.
bool SequencedSocketData::IsPaused() const {
  // Both states should not be paused.
  DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
  return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
}

// Leaves the paused state: consumes the pause event, then directly completes
// the following write or read if it is due and the socket is waiting on it.
// Reports a gtest failure when not paused.
void SequencedSocketData::Resume() {
  if (!IsPaused()) {
    ADD_FAILURE() << "Unable to Resume when not paused.";
    return;
  }

  // Step past the pause event itself.
  sequence_number_++;
  if (read_state_ == IoState::kPaused) {
    read_state_ = IoState::kPending;
    helper_.AdvanceRead();
  } else {  // write_state_ == IoState::kPaused
    write_state_ = IoState::kPending;
    helper_.AdvanceWrite();
  }

  if (!helper_.AllWriteDataConsumed() &&
      helper_.PeekWrite().sequence_number == sequence_number_) {
    // The next event hasn't even started yet.  Pausing isn't really needed in
    // that case, but may as well support it.
    if (write_state_ != IoState::kPending)
      return;
    write_state_ = IoState::kCompleting;
    OnWriteComplete();
    return;
  }

  // The next event is not a write, so a read must remain.
  CHECK(!helper_.AllReadDataConsumed());

  // The next event hasn't even started yet.  Pausing isn't really needed in
  // that case, but may as well support it.
  if (read_state_ != IoState::kPending)
    return;
  read_state_ = IoState::kCompleting;
  OnReadComplete();
}

// Runs a run loop until a pause event quits it; returns immediately when
// already paused. Must not be called while another call is in progress.
void SequencedSocketData::RunUntilPaused() {
  CHECK(!run_until_paused_run_loop_);

  if (IsPaused())
    return;

  run_until_paused_run_loop_ = std::make_unique();
  run_until_paused_run_loop_->Run();
  run_until_paused_run_loop_.reset();
  DCHECK(IsPaused());
}

// If a read is pending and due at the current sequence number, either pauses
// (result ERR_IO_PENDING) or posts a task that completes it.
void SequencedSocketData::MaybePostReadCompleteTask() {
  NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
  // Only trigger the next read to complete if there is already a read pending
  // which should complete at the current sequence number.
  if (read_state_ != IoState::kPending ||
      helper_.PeekRead().sequence_number != sequence_number_) {
    return;
  }

  // If the result is ERR_IO_PENDING, then pause.
  if (helper_.PeekRead().result == ERR_IO_PENDING) {
    NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
    read_state_ = IoState::kPaused;
    if (run_until_paused_run_loop_)
      run_until_paused_run_loop_->Quit();
    return;
  }

  NET_TRACE(1, " ****** ") << "Posting task to complete read: "
                           << sequence_number_;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
                                weak_factory_.GetWeakPtr()));
  CHECK_NE(IoState::kCompleting, write_state_);
  read_state_ = IoState::kCompleting;
}

// If a write is pending and due at the current sequence number, either pauses
// (result ERR_IO_PENDING) or posts a task that completes it.
void SequencedSocketData::MaybePostWriteCompleteTask() {
  NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
  // Only trigger the next write to complete if there is already a write pending
  // which should complete at the current sequence number.
  if (write_state_ != IoState::kPending ||
      helper_.PeekWrite().sequence_number != sequence_number_) {
    return;
  }

  // If the result is ERR_IO_PENDING, then pause.
  if (helper_.PeekWrite().result == ERR_IO_PENDING) {
    NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
    write_state_ = IoState::kPaused;
    if (run_until_paused_run_loop_)
      run_until_paused_run_loop_->Quit();
    return;
  }

  NET_TRACE(1, " ****** ") << "Posting task to complete write: "
                           << sequence_number_;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
                                weak_factory_.GetWeakPtr()));
  CHECK_NE(IoState::kCompleting, read_state_);
  write_state_ = IoState::kCompleting;
}

// Returns to the initial state: positions and sequence number rewound, both
// IO states idle, and outstanding weak pointers (and so any posted completion
// tasks bound to them) invalidated.
void SequencedSocketData::Reset() {
  helper_.Reset();
  sequence_number_ = 0;
  read_state_ = IoState::kIdle;
  write_state_ = IoState::kIdle;
  weak_factory_.InvalidateWeakPtrs();
}

// Completes the read that is in the completing state: consumes it, advances
// the sequence number and, if a socket is attached, delivers the read to it.
void SequencedSocketData::OnReadComplete() {
  CHECK_EQ(IoState::kCompleting, read_state_);
  NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;

  MockRead data = helper_.AdvanceRead();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  read_state_ = IoState::kIdle;

  // The result of this read completing might trigger the completion
  // of a pending write. If so, post a task to complete the write later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnReadComplete(), trigger the write task to be posted
  // before calling that.
  MaybePostWriteCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete read";
    return;
  }

  NET_TRACE(1, " *** ") << "Completing socket read for: "
                        << data.sequence_number;
  DumpMockReadWrite(data);
  socket()->OnReadComplete(data);
  NET_TRACE(1, " *** ") << "Done";
}

// Completes the write that is in the completing state: consumes it, advances
// the sequence number and, if a socket is attached, reports the result to it
// (the byte count when the mocked result is OK, otherwise the error code).
void SequencedSocketData::OnWriteComplete() {
  CHECK_EQ(IoState::kCompleting, write_state_);
  NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;

  const MockWrite& data = helper_.AdvanceWrite();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  write_state_ = IoState::kIdle;
  int rv = data.result == OK ? data.data_len : data.result;

  // The result of this write completing might trigger the completion
  // of a pending read. If so, post a task to complete the read later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnWriteComplete(), trigger the read task to be posted
  // before calling that.
  MaybePostReadCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete write";
    return;
  }

  NET_TRACE(1, " *** ") << " Completing socket write for: "
                        << data.sequence_number;
  socket()->OnWriteComplete(rv);
  NET_TRACE(1, " *** ") << "Done";
}

SequencedSocketData::~SequencedSocketData() = default;

MockClientSocketFactory::MockClientSocketFactory() = default;

MockClientSocketFactory::~MockClientSocketFactory() = default;

// Queues |data| in the general provider list.
void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
  mock_data_.Add(data);
}

// Queues |data| in the TCP-specific provider list, which
// CreateTransportClientSocket() consults before the general list.
void MockClientSocketFactory::AddTcpSocketDataProvider(
    SocketDataProvider* data) {
  mock_tcp_data_.Add(data);
}

// Queues |data| in the SSL provider list.
void MockClientSocketFactory::AddSSLSocketDataProvider(
    SSLSocketDataProvider* data) {
  mock_ssl_data_.Add(data);
}

// Resets the next-index of the general and SSL provider lists.
// NOTE(review): |mock_tcp_data_| is not reset here -- confirm whether that is
// intentional.
void MockClientSocketFactory::ResetNextMockIndexes() {
  mock_data_.ResetNextIndex();
  mock_ssl_data_.ResetNextIndex();
}

// Creates a datagram socket backed by the next provider in |mock_data_|. With
// RANDOM_BIND the socket gets a random source port in [1025, 65535]. The
// socket's source port is appended to |udp_client_socket_ports_|.
std::unique_ptr
MockClientSocketFactory::CreateDatagramClientSocket(
    DatagramSocket::BindType bind_type,
    NetLog* net_log,
    const NetLogSource& source) {
  SocketDataProvider* data_provider = mock_data_.GetNext();
  auto socket = std::make_unique(data_provider, net_log);
  if (bind_type == DatagramSocket::RANDOM_BIND)
    socket->set_source_port(static_cast(base::RandInt(1025, 65535)));
  udp_client_socket_ports_.push_back(socket->source_port());
  return std::move(socket);
}

// Creates a transport socket for |addresses|. Uses the next TCP-specific
// provider when one is available and falls back to the general list
// otherwise. Read-if-ready support is enabled on the socket when
// |enable_read_if_ready_| is set.
std::unique_ptr
MockClientSocketFactory::CreateTransportClientSocket(
    const AddressList& addresses,
    std::unique_ptr socket_performance_watcher,
    NetworkQualityEstimator* network_quality_estimator,
    NetLog* net_log,
    const NetLogSource& source) {
  SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
  if (!data_provider)
    data_provider = mock_data_.GetNext();
  auto socket = std::make_unique(addresses, net_log, data_provider);
  if (enable_read_if_ready_)
    socket->set_enable_read_if_ready(enable_read_if_ready_);
  return std::move(socket);
}

// Creates an SSL socket on top of |stream_socket| using the next SSL provider.
// Each expectation that is set on the provider is checked with gtest
// assertions against |ssl_config|, |context| or |host_and_port| first; the SSL
// version range is always checked.
std::unique_ptr MockClientSocketFactory::CreateSSLClientSocket(
    SSLClientContext* context,
    std::unique_ptr stream_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config) {
  SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
  if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
    EXPECT_TRUE(base::ranges::equal(
        next_ssl_data->next_protos_expected_in_ssl_config.value(),
        ssl_config.alpn_protos));
  }
  if (next_ssl_data->expected_application_settings) {
    EXPECT_EQ(*next_ssl_data->expected_application_settings,
              ssl_config.application_settings);
  }

  // The protocol version used is a combination of the per-socket SSLConfig and
  // the SSLConfigService.
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_min,
      ssl_config.version_min_override.value_or(context->config().version_min));
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_max,
      ssl_config.version_max_override.value_or(context->config().version_max));

  if (next_ssl_data->expected_early_data_enabled) {
    EXPECT_EQ(*next_ssl_data->expected_early_data_enabled,
              ssl_config.early_data_enabled);
  }

  if (next_ssl_data->expected_send_client_cert) {
    // Client certificate preferences come from |context|.
    scoped_refptr client_cert;
    scoped_refptr client_private_key;
    bool send_client_cert = context->GetClientCertificate(
        host_and_port, &client_cert, &client_private_key);

    EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
    // Note |send_client_cert| may be true while |client_cert| is null if the
    // socket is configured to continue without a certificate, as opposed to
    // surfacing the certificate challenge.
    EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
    if (next_ssl_data->expected_client_cert && client_cert) {
      EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
          client_cert.get()));
    }
  }
  if (next_ssl_data->expected_host_and_port) {
    EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
  }
  if (next_ssl_data->expected_ignore_certificate_errors) {
    EXPECT_EQ(*next_ssl_data->expected_ignore_certificate_errors,
              ssl_config.ignore_certificate_errors);
  }
  if (next_ssl_data->expected_network_anonymization_key) {
    EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
              ssl_config.network_anonymization_key);
  }
  if (next_ssl_data->expected_ech_config_list) {
    EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
              ssl_config.ech_config_list);
  }
  return std::make_unique(
      std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
}

// Stores |net_log| and sets default addresses: local 192.0.2.33:123 and peer
// 192.0.2.33:0.
MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
    : net_log_(net_log) {
  local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
  peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
}

// Ignores |size| and reports success.
int MockClientSocket::SetReceiveBufferSize(int32_t size) {
  return OK;
}

// Ignores |size| and reports success.
int MockClientSocket::SetSendBufferSize(int32_t size) {
  return OK;
}

// Records |local_addr| as the local address and reports success.
int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
  local_addr_ = local_addr;
  return net::OK;
}

// Ignores |no_delay| and reports success.
bool MockClientSocket::SetNoDelay(bool no_delay) {
  return true;
}

// Ignores the arguments and reports success.
bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
  return true;
}

void MockClientSocket::Disconnect() {
  connected_ = false;
}

bool MockClientSocket::IsConnected() const {
  return connected_;
}

// Same as IsConnected() in this base class.
bool MockClientSocket::IsConnectedAndIdle() const {
  return connected_;
}

// Writes the peer address to |address|, or returns ERR_SOCKET_NOT_CONNECTED
// without touching |address| when not connected.
int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
  if (!IsConnected())
    return ERR_SOCKET_NOT_CONNECTED;
  *address = peer_addr_;
  return OK;
}

// Writes the local address to |address|; no connection check is made.
int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
  *address = local_addr_;
  return OK;
}

const NetLogWithSource& MockClientSocket::NetLog() const {
  return net_log_;
}

NextProto MockClientSocket::GetNegotiatedProtocol() const {
  return kProtoUnknown;
}

MockClientSocket::~MockClientSocket() = default;

// Posts a task that runs |callback| with |result|. The task is bound to a weak
// pointer to this socket.
void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
                                        int result) {
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
                     std::move(callback), result));
}

// Runs |callback| with |result|.
void MockClientSocket::RunCallback(CompletionOnceCallback callback,
                                   int result) {
  std::move(callback).Run(result);
}

// Binds the socket to |data| (must be non-null): takes the peer address from
// its connect data, registers this socket with it, and checks |addresses|
// against the provider's expected addresses when those are set.
MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
                                         net::NetLog* net_log,
                                         SocketDataProvider* data)
    : MockClientSocket(
          NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
      addresses_(addresses),
      data_(data),
      read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
  DCHECK(data_);
  peer_addr_ = data->connect_data().peer_addr;
  data_->Initialize(this);
  if (data_->expected_addresses()) {
    EXPECT_EQ(*data_->expected_addresses(), addresses);
  }
}

// Detaches from the data provider if one is still set.
MockTCPClientSocket::~MockTCPClientSocket() {
  if (data_)
    data_->DetachSocket();
}

// Attempts the read through ReadIfReadyImpl() with RetryRead as its callback.
// When that returns ERR_IO_PENDING, the buffer, length and |callback| are
// stored in the pending_read_* members. Returns the ReadIfReadyImpl() result.
int MockTCPClientSocket::Read(IOBuffer* buf,
                              int buf_len,
                              CompletionOnceCallback callback) {
  // If the buffer is already in use, a read is already in progress!
  DCHECK(!pending_read_buf_);
  // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
  // takes a weak ptr of the base class, MockClientSocket.
  int rv = ReadIfReadyImpl(
      buf, buf_len,
      base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
  if (rv == ERR_IO_PENDING) {
    DCHECK(callback);

    pending_read_buf_ = buf;
    pending_read_buf_len_ = buf_len;
    pending_read_callback_ = std::move(callback);
  }
  return rv;
}

// Forwards to ReadIfReadyImpl(), or returns ERR_READ_IF_READY_NOT_IMPLEMENTED
// when read-if-ready support is not enabled on this socket.
int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
                                     int buf_len,
                                     CompletionOnceCallback callback) {
  DCHECK(!pending_read_if_ready_callback_);

  if (!enable_read_if_ready_)
    return ERR_READ_IF_READY_NOT_IMPLEMENTED;
  return ReadIfReadyImpl(buf, buf_len, std::move(callback));
}

// Drops the pending read-if-ready callback and cancels the provider's pending
// read. A read-if-ready callback must be pending.
int MockTCPClientSocket::CancelReadIfReady() {
  DCHECK(pending_read_if_ready_callback_);

  pending_read_if_ready_callback_.Reset();
  data_->CancelPendingRead();
  return OK;
}

// Passes the first |buf_len| bytes of |buf| to the data provider and returns
// its mocked result. ERR_IO_PENDING from the provider stores |callback| for a
// later completion; an ASYNC result is delivered through a posted task and
// ERR_IO_PENDING is returned; anything else is returned directly. Returns
// ERR_UNEXPECTED when not connected or without a data provider.
int MockTCPClientSocket::Write(
    IOBuffer* buf,
    int buf_len,
    CompletionOnceCallback callback,
    const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
  DCHECK(buf);
  DCHECK_GT(buf_len, 0);

  if (!connected_ || !data_)
    return ERR_UNEXPECTED;

  std::string data(buf->data(), buf_len);
  MockWriteResult write_result = data_->OnWrite(data);

  was_used_to_convey_data_ = true;

  if (write_result.result == ERR_CONNECTION_CLOSED) {
    // This MockWrite is just a marker to instruct us to set
    // peer_closed_connection_.
    peer_closed_connection_ = true;
  }
  // ERR_IO_PENDING is a signal that the socket data will call back
  // asynchronously later.
  if (write_result.result == ERR_IO_PENDING) {
    pending_write_callback_ = std::move(callback);
    return ERR_IO_PENDING;
  }

  if (write_result.mode == ASYNC) {
    RunCallbackAsync(std::move(callback), write_result.result);
    return ERR_IO_PENDING;
  }

  return write_result.result;
}

// Records |size| on the data provider and returns the provider's configured
// result. Returns ERR_UNEXPECTED when not connected.
int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
  if (!connected_)
    return net::ERR_UNEXPECTED;
  data_->set_receive_buffer_size(size);
  return data_->set_receive_buffer_size_result();
}

// Records |size| on the data provider and returns the provider's configured
// result. Returns ERR_UNEXPECTED when not connected.
int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
  if (!connected_)
    return net::ERR_UNEXPECTED;
  data_->set_send_buffer_size(size);
  return data_->set_send_buffer_size_result();
}

// Records |no_delay| on the data provider and returns the provider's
// configured result. Returns false when not connected.
bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
  if (!connected_)
    return false;
  data_->set_no_delay(no_delay);
  return data_->set_no_delay_result();
}

// Records the keep-alive settings on the data provider and returns the
// provider's configured result. Returns false when not connected.
bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
  if (!connected_)
    return false;
  data_->set_keep_alive(enable, delay);
  return data_->set_keep_alive_result();
}

// Stores |before_connect_callback|. May only be set once, and only while the
// socket is not connected.
void MockTCPClientSocket::SetBeforeConnectCallback(
    const BeforeConnectCallback& before_connect_callback) {
  DCHECK(!before_connect_callback_);
  DCHECK(!connected_);

  before_connect_callback_ = before_connect_callback;
}

int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
  if (!data_)
    return ERR_UNEXPECTED;

  if (connected_)
    return OK;

  // Setting socket options fails if not connected, so need to set this before
  // calling |before_connect_callback_|.
connected_ = true; if (before_connect_callback_) { for (size_t index = 0; index < addresses_.size(); index++) { int result = before_connect_callback_.Run(); if (data_->connect_data().first_attempt_fails && index == 0) { continue; } DCHECK_NE(result, ERR_IO_PENDING); if (result != net::OK) { connected_ = false; return result; } break; } } peer_closed_connection_ = false; int result = data_->connect_data().result; IoMode mode = data_->connect_data().mode; if (mode == SYNCHRONOUS) return result; DCHECK(callback); if (result == ERR_IO_PENDING) pending_connect_callback_ = std::move(callback); else RunCallbackAsync(std::move(callback), result); return ERR_IO_PENDING; } void MockTCPClientSocket::Disconnect() { MockClientSocket::Disconnect(); pending_connect_callback_.Reset(); pending_read_callback_.Reset(); } bool MockTCPClientSocket::IsConnected() const { if (!data_) return false; return connected_ && !peer_closed_connection_; } bool MockTCPClientSocket::IsConnectedAndIdle() const { if (!data_) return false; return IsConnected() && data_->IsIdle(); } int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const { if (addresses_.empty()) return MockClientSocket::GetPeerAddress(address); if (data_->connect_data().first_attempt_fails) { DCHECK_GE(addresses_.size(), 2U); *address = addresses_[1]; } else { *address = addresses_[0]; } return OK; } bool MockTCPClientSocket::WasEverUsed() const { return was_used_to_convey_data_; } bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return false; } void MockTCPClientSocket::OnReadComplete(const MockRead& data) { // If |data_| has been destroyed, safest to just do nothing. if (!data_) return; // There must be a read pending. DCHECK(pending_read_if_ready_callback_); // You can't complete a read with another ERR_IO_PENDING status code. DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. 
DCHECK(need_read_data_); read_data_ = data; need_read_data_ = false; // The caller is simulating that this IO completes right now. Don't // let CompleteRead() schedule a callback. read_data_.mode = SYNCHRONOUS; RunCallback(std::move(pending_read_if_ready_callback_), read_data_.result > 0 ? OK : read_data_.result); } void MockTCPClientSocket::OnWriteComplete(int rv) { // If |data_| has been destroyed, safest to just do nothing. if (!data_) return; // There must be a read pending. DCHECK(!pending_write_callback_.is_null()); RunCallback(std::move(pending_write_callback_), rv); } void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { // If |data_| has been destroyed, safest to just do nothing. if (!data_) return; RunCallback(std::move(pending_connect_callback_), data.result); } void MockTCPClientSocket::OnDataProviderDestroyed() { data_ = nullptr; } void MockTCPClientSocket::RetryRead(int rv) { DCHECK(pending_read_callback_); DCHECK(pending_read_buf_.get()); DCHECK_LT(0, pending_read_buf_len_); if (rv == OK) { rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_, base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this))); if (rv == ERR_IO_PENDING) return; } pending_read_buf_ = nullptr; pending_read_buf_len_ = 0; RunCallback(std::move(pending_read_callback_), rv); } int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { if (!connected_ || !data_) return ERR_UNEXPECTED; DCHECK(!pending_read_if_ready_callback_); if (need_read_data_) { read_data_ = data_->OnRead(); if (read_data_.result == ERR_CONNECTION_CLOSED) { // This MockRead is just a marker to instruct us to set // peer_closed_connection_. peer_closed_connection_ = true; } if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { // This MockRead is just a marker to instruct us to set // peer_closed_connection_. Skip it and get the next one. 
read_data_ = data_->OnRead(); peer_closed_connection_ = true; } // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility // to complete the async IO manually later (via OnReadComplete). if (read_data_.result == ERR_IO_PENDING) { // We need to be using async IO in this case. DCHECK(!callback.is_null()); pending_read_if_ready_callback_ = std::move(callback); return ERR_IO_PENDING; } need_read_data_ = false; } int result = read_data_.result; DCHECK_NE(ERR_IO_PENDING, result); if (read_data_.mode == ASYNC) { DCHECK(!callback.is_null()); read_data_.mode = SYNCHRONOUS; pending_read_if_ready_callback_ = std::move(callback); // base::Unretained() is safe here because RunCallbackAsync will wrap it // with a callback associated with a weak ptr. RunCallbackAsync( base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback, base::Unretained(this)), result); return ERR_IO_PENDING; } was_used_to_convey_data_ = true; if (read_data_.data) { if (read_data_.data_len - read_offset_ > 0) { result = std::min(buf_len, read_data_.data_len - read_offset_); memcpy(buf->data(), read_data_.data + read_offset_, result); read_offset_ += result; if (read_offset_ == read_data_.data_len) { need_read_data_ = true; read_offset_ = 0; } } else { result = 0; // EOF } } return result; } void MockTCPClientSocket::RunReadIfReadyCallback(int result) { // If ReadIfReady is already canceled, do nothing. 
if (!pending_read_if_ready_callback_) return; std::move(pending_read_if_ready_callback_).Run(result); } // static void MockSSLClientSocket::ConnectCallback( MockSSLClientSocket* ssl_client_socket, CompletionOnceCallback callback, int rv) { if (rv == OK) ssl_client_socket->connected_ = true; std::move(callback).Run(rv); } MockSSLClientSocket::MockSSLClientSocket( std::unique_ptr stream_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLSocketDataProvider* data) : net_log_(stream_socket->NetLog()), stream_socket_(std::move(stream_socket)), data_(data) { DCHECK(data_); peer_addr_ = data->connect.peer_addr; } MockSSLClientSocket::~MockSSLClientSocket() { Disconnect(); } int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { return stream_socket_->Read(buf, buf_len, std::move(callback)); } int MockSSLClientSocket::ReadIfReady(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback)); } int MockSSLClientSocket::Write( IOBuffer* buf, int buf_len, CompletionOnceCallback callback, const NetworkTrafficAnnotationTag& traffic_annotation) { if (!data_->is_confirm_data_consumed) data_->write_called_before_confirm = true; return stream_socket_->Write(buf, buf_len, std::move(callback), traffic_annotation); } int MockSSLClientSocket::CancelReadIfReady() { return stream_socket_->CancelReadIfReady(); } int MockSSLClientSocket::Connect(CompletionOnceCallback callback) { DCHECK(stream_socket_->IsConnected()); data_->is_connect_data_consumed = true; if (data_->connect.result == OK) connected_ = true; RunClosureIfNonNull(std::move(data_->connect_callback)); if (data_->connect.mode == ASYNC) { RunCallbackAsync(std::move(callback), data_->connect.result); return ERR_IO_PENDING; } return data_->connect.result; } void MockSSLClientSocket::Disconnect() { if (stream_socket_ != nullptr) stream_socket_->Disconnect(); } void 
MockSSLClientSocket::RunConfirmHandshakeCallback( CompletionOnceCallback callback, int result) { DCHECK(in_confirm_handshake_); in_confirm_handshake_ = false; data_->is_confirm_data_consumed = true; std::move(callback).Run(result); } int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) { DCHECK(stream_socket_->IsConnected()); DCHECK(!in_confirm_handshake_); if (data_->is_confirm_data_consumed) return data_->confirm.result; RunClosureIfNonNull(std::move(data_->confirm_callback)); if (data_->confirm.mode == ASYNC) { in_confirm_handshake_ = true; RunCallbackAsync( base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback, base::Unretained(this), std::move(callback)), data_->confirm.result); return ERR_IO_PENDING; } data_->is_confirm_data_consumed = true; if (data_->confirm.result == ERR_IO_PENDING) { // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()` // never completes. in_confirm_handshake_ = true; } return data_->confirm.result; } bool MockSSLClientSocket::IsConnected() const { return stream_socket_->IsConnected(); } bool MockSSLClientSocket::IsConnectedAndIdle() const { return stream_socket_->IsConnectedAndIdle(); } bool MockSSLClientSocket::WasEverUsed() const { return stream_socket_->WasEverUsed(); } int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const { *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123); return OK; } int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const { return stream_socket_->GetPeerAddress(address); } NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { return data_->next_proto; } std::optional MockSSLClientSocket::GetPeerApplicationSettings() const { return data_->peer_application_settings; } bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) { *requested_ssl_info = data_->ssl_info; return true; } void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) { return stream_socket_->ApplySocketTag(tag); } const NetLogWithSource& 
MockSSLClientSocket::NetLog() const { return net_log_; } int64_t MockSSLClientSocket::GetTotalReceivedBytes() const { NOTIMPLEMENTED(); return 0; } int64_t MockClientSocket::GetTotalReceivedBytes() const { NOTIMPLEMENTED(); return 0; } int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) { return OK; } int MockSSLClientSocket::SetSendBufferSize(int32_t size) { return OK; } void MockSSLClientSocket::GetSSLCertRequestInfo( SSLCertRequestInfo* cert_request_info) const { DCHECK(cert_request_info); if (data_->cert_request_info) { cert_request_info->host_and_port = data_->cert_request_info->host_and_port; cert_request_info->is_proxy = data_->cert_request_info->is_proxy; cert_request_info->cert_authorities = data_->cert_request_info->cert_authorities; cert_request_info->signature_algorithms = data_->cert_request_info->signature_algorithms; } else { cert_request_info->Reset(); } } int MockSSLClientSocket::ExportKeyingMaterial(std::string_view label, bool has_context, std::string_view context, unsigned char* out, unsigned int outlen) { memset(out, 'A', outlen); return OK; } std::vector MockSSLClientSocket::GetECHRetryConfigs() { return data_->ech_retry_configs; } void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback, int result) { base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( FROM_HERE, base::BindOnce(&MockSSLClientSocket::RunCallback, weak_factory_.GetWeakPtr(), std::move(callback), result)); } void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback, int result) { std::move(callback).Run(result); } void MockSSLClientSocket::OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } void MockSSLClientSocket::OnWriteComplete(int rv) { NOTIMPLEMENTED(); } void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { NOTIMPLEMENTED(); } MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log) : data_(data), read_data_(SYNCHRONOUS, ERR_UNEXPECTED), source_host_(IPAddress(192, 0, 2, 
33)), net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) { if (data_) { data_->Initialize(this); peer_addr_ = data->connect_data().peer_addr; } } MockUDPClientSocket::~MockUDPClientSocket() { if (data_) data_->DetachSocket(); } int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { DCHECK(callback); if (!connected_ || !data_) return ERR_UNEXPECTED; data_transferred_ = true; // If the buffer is already in use, a read is already in progress! DCHECK(!pending_read_buf_); // Store our async IO data. pending_read_buf_ = buf; pending_read_buf_len_ = buf_len; pending_read_callback_ = std::move(callback); if (need_read_data_) { read_data_ = data_->OnRead(); last_tos_ = read_data_.tos; // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility // to complete the async IO manually later (via OnReadComplete). if (read_data_.result == ERR_IO_PENDING) { // We need to be using async IO in this case. DCHECK(!pending_read_callback_.is_null()); return ERR_IO_PENDING; } need_read_data_ = false; } return CompleteRead(); } int MockUDPClientSocket::Write( IOBuffer* buf, int buf_len, CompletionOnceCallback callback, const NetworkTrafficAnnotationTag& /* traffic_annotation */) { DCHECK(buf); DCHECK_GT(buf_len, 0); DCHECK(callback); if (!connected_ || !data_) return ERR_UNEXPECTED; data_transferred_ = true; std::string data(buf->data(), buf_len); MockWriteResult write_result = data_->OnWrite(data); // ERR_IO_PENDING is a signal that the socket data will call back // asynchronously. 
if (write_result.result == ERR_IO_PENDING) { pending_write_callback_ = std::move(callback); return ERR_IO_PENDING; } if (write_result.mode == ASYNC) { RunCallbackAsync(std::move(callback), write_result.result); return ERR_IO_PENDING; } return write_result.result; } int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) { return OK; } int MockUDPClientSocket::SetSendBufferSize(int32_t size) { return OK; } int MockUDPClientSocket::SetDoNotFragment() { return OK; } int MockUDPClientSocket::SetRecvTos() { return OK; } int MockUDPClientSocket::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) { return OK; } void MockUDPClientSocket::Close() { connected_ = false; } int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const { if (!data_) return ERR_UNEXPECTED; *address = peer_addr_; return OK; } int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const { *address = IPEndPoint(source_host_, source_port_); return OK; } void MockUDPClientSocket::UseNonBlockingIO() {} int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) { return OK; } const NetLogWithSource& MockUDPClientSocket::NetLog() const { return net_log_; } int MockUDPClientSocket::Connect(const IPEndPoint& address) { if (!data_) return ERR_UNEXPECTED; DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING); connected_ = true; peer_addr_ = address; return data_->connect_data().result; } int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network, const IPEndPoint& address) { DCHECK(!connected_); if (!data_) return ERR_UNEXPECTED; DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING); network_ = network; connected_ = true; peer_addr_ = address; return data_->connect_data().result; } int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) { DCHECK(!connected_); if (!data_) return ERR_UNEXPECTED; DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING); network_ = kDefaultNetworkForTests; connected_ = true; peer_addr_ = address; return 
data_->connect_data().result; } int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address, CompletionOnceCallback callback) { DCHECK(callback); if (!data_) { return ERR_UNEXPECTED; } connected_ = true; peer_addr_ = address; int result = data_->connect_data().result; IoMode mode = data_->connect_data().mode; if (mode == SYNCHRONOUS) { return result; } RunCallbackAsync(std::move(callback), result); return ERR_IO_PENDING; } int MockUDPClientSocket::ConnectUsingNetworkAsync( handles::NetworkHandle network, const IPEndPoint& address, CompletionOnceCallback callback) { DCHECK(callback); DCHECK(!connected_); if (!data_) return ERR_UNEXPECTED; network_ = network; connected_ = true; peer_addr_ = address; int result = data_->connect_data().result; IoMode mode = data_->connect_data().mode; if (mode == SYNCHRONOUS) { return result; } RunCallbackAsync(std::move(callback), result); return ERR_IO_PENDING; } int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync( const IPEndPoint& address, CompletionOnceCallback callback) { DCHECK(!connected_); if (!data_) return ERR_UNEXPECTED; network_ = kDefaultNetworkForTests; connected_ = true; peer_addr_ = address; int result = data_->connect_data().result; IoMode mode = data_->connect_data().mode; if (mode == SYNCHRONOUS) { return result; } RunCallbackAsync(std::move(callback), result); return ERR_IO_PENDING; } handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const { return network_; } void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) { tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_; tag_ = tag; } DscpAndEcn MockUDPClientSocket::GetLastTos() const { return TosToDscpAndEcn(last_tos_); } void MockUDPClientSocket::OnReadComplete(const MockRead& data) { if (!data_) return; // There must be a read pending. DCHECK(pending_read_buf_.get()); DCHECK(pending_read_callback_); // You can't complete a read with another ERR_IO_PENDING status code. 
DCHECK_NE(ERR_IO_PENDING, data.result); // Since we've been waiting for data, need_read_data_ should be true. DCHECK(need_read_data_); read_data_ = data; last_tos_ = data.tos; need_read_data_ = false; // The caller is simulating that this IO completes right now. Don't // let CompleteRead() schedule a callback. read_data_.mode = SYNCHRONOUS; CompletionOnceCallback callback = std::move(pending_read_callback_); int rv = CompleteRead(); RunCallback(std::move(callback), rv); } void MockUDPClientSocket::OnWriteComplete(int rv) { if (!data_) return; // There must be a read pending. DCHECK(!pending_write_callback_.is_null()); RunCallback(std::move(pending_write_callback_), rv); } void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { NOTIMPLEMENTED(); } void MockUDPClientSocket::OnDataProviderDestroyed() { data_ = nullptr; } int MockUDPClientSocket::CompleteRead() { DCHECK(pending_read_buf_.get()); DCHECK(pending_read_buf_len_ > 0); // Save the pending async IO data and reset our |pending_| state. 
scoped_refptr buf = pending_read_buf_; int buf_len = pending_read_buf_len_; CompletionOnceCallback callback = std::move(pending_read_callback_); pending_read_buf_ = nullptr; pending_read_buf_len_ = 0; int result = read_data_.result; DCHECK(result != ERR_IO_PENDING); if (read_data_.data) { if (read_data_.data_len - read_offset_ > 0) { result = std::min(buf_len, read_data_.data_len - read_offset_); memcpy(buf->data(), read_data_.data + read_offset_, result); read_offset_ += result; if (read_offset_ == read_data_.data_len) { need_read_data_ = true; read_offset_ = 0; } } else { result = 0; // EOF } } if (read_data_.mode == ASYNC) { DCHECK(!callback.is_null()); RunCallbackAsync(std::move(callback), result); return ERR_IO_PENDING; } return result; } void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback, int result) { base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( FROM_HERE, base::BindOnce(&MockUDPClientSocket::RunCallback, weak_factory_.GetWeakPtr(), std::move(callback), result)); } void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback, int result) { std::move(callback).Run(result); } TestSocketRequest::TestSocketRequest( std::vector>* request_order, size_t* completion_count) : request_order_(request_order), completion_count_(completion_count) { DCHECK(request_order); DCHECK(completion_count); } TestSocketRequest::~TestSocketRequest() = default; void TestSocketRequest::OnComplete(int result) { SetResult(result); (*completion_count_)++; request_order_->push_back(this); } // static const int ClientSocketPoolTest::kIndexOutOfBounds = -1; // static const int ClientSocketPoolTest::kRequestNotFound = -2; ClientSocketPoolTest::ClientSocketPoolTest() = default; ClientSocketPoolTest::~ClientSocketPoolTest() = default; int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const { index--; if (index >= requests_.size()) return kIndexOutOfBounds; for (size_t i = 0; i < request_order_.size(); i++) if (requests_[index].get() == 
request_order_[i]) return i + 1; return kRequestNotFound; } bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) { for (std::unique_ptr& it : requests_) { if (it->handle()->is_initialized()) { if (keep_alive == NO_KEEP_ALIVE) it->handle()->socket()->Disconnect(); it->handle()->Reset(); base::RunLoop().RunUntilIdle(); return true; } } return false; } void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { bool released_one; do { released_one = ReleaseOneConnection(keep_alive); } while (released_one); } MockTransportClientSocketPool::MockConnectJob::MockConnectJob( std::unique_ptr socket, ClientSocketHandle* handle, const SocketTag& socket_tag, CompletionOnceCallback callback, RequestPriority priority) : socket_(std::move(socket)), handle_(handle), socket_tag_(socket_tag), user_callback_(std::move(callback)), priority_(priority) {} MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default; int MockTransportClientSocketPool::MockConnectJob::Connect() { socket_->ApplySocketTag(socket_tag_); int rv = socket_->Connect( base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this))); if (rv != ERR_IO_PENDING) { user_callback_.Reset(); OnConnect(rv); } return rv; } bool MockTransportClientSocketPool::MockConnectJob::CancelHandle( const ClientSocketHandle* handle) { if (handle != handle_) return false; socket_.reset(); handle_ = nullptr; user_callback_.Reset(); return true; } void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { if (!socket_.get()) return; if (rv == OK) { handle_->SetSocket(std::move(socket_)); // Needed for socket pool tests that layer other sockets on top of mock // sockets. 
LoadTimingInfo::ConnectTiming connect_timing; base::TimeTicks now = base::TimeTicks::Now(); connect_timing.domain_lookup_start = now; connect_timing.domain_lookup_end = now; connect_timing.connect_start = now; connect_timing.connect_end = now; handle_->set_connect_timing(connect_timing); } else { socket_.reset(); // Needed to test copying of ConnectionAttempts in SSL ConnectJob. ConnectionAttempts attempts; attempts.push_back(ConnectionAttempt(IPEndPoint(), rv)); handle_->set_connection_attempts(attempts); } handle_ = nullptr; if (!user_callback_.is_null()) { std::move(user_callback_).Run(rv); } } MockTransportClientSocketPool::MockTransportClientSocketPool( int max_sockets, int max_sockets_per_group, const CommonConnectJobParams* common_connect_job_params) : TransportClientSocketPool( max_sockets, max_sockets_per_group, base::Seconds(10) /* unused_idle_socket_timeout */, ProxyChain::Direct(), false /* is_for_websockets */, common_connect_job_params), client_socket_factory_(common_connect_job_params->client_socket_factory) { } MockTransportClientSocketPool::~MockTransportClientSocketPool() = default; int MockTransportClientSocketPool::RequestSocket( const ClientSocketPool::GroupId& group_id, scoped_refptr socket_params, const std::optional& proxy_annotation_tag, RequestPriority priority, const SocketTag& socket_tag, RespectLimits respect_limits, ClientSocketHandle* handle, CompletionOnceCallback callback, const ProxyAuthCallback& on_auth_callback, const NetLogWithSource& net_log) { last_request_priority_ = priority; std::unique_ptr socket = client_socket_factory_->CreateTransportClientSocket( AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource()); auto job = std::make_unique( std::move(socket), handle, socket_tag, std::move(callback), priority); auto* job_ptr = job.get(); job_list_.push_back(std::move(job)); handle->set_group_generation(1); return job_ptr->Connect(); } void MockTransportClientSocketPool::SetPriority( const ClientSocketPool::GroupId& 
group_id, ClientSocketHandle* handle, RequestPriority priority) { for (auto& job : job_list_) { if (job->handle() == handle) { job->set_priority(priority); return; } } NOTREACHED(); } void MockTransportClientSocketPool::CancelRequest( const ClientSocketPool::GroupId& group_id, ClientSocketHandle* handle, bool cancel_connect_job) { for (std::unique_ptr& it : job_list_) { if (it->CancelHandle(handle)) { cancel_count_++; break; } } } void MockTransportClientSocketPool::ReleaseSocket( const ClientSocketPool::GroupId& group_id, std::unique_ptr socket, int64_t generation) { EXPECT_EQ(1, generation); release_count_++; } WrappedStreamSocket::WrappedStreamSocket( std::unique_ptr transport) : transport_(std::move(transport)) {} WrappedStreamSocket::~WrappedStreamSocket() = default; int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) { NOTREACHED(); return ERR_FAILED; } int WrappedStreamSocket::Connect(CompletionOnceCallback callback) { return transport_->Connect(std::move(callback)); } void WrappedStreamSocket::Disconnect() { transport_->Disconnect(); } bool WrappedStreamSocket::IsConnected() const { return transport_->IsConnected(); } bool WrappedStreamSocket::IsConnectedAndIdle() const { return transport_->IsConnectedAndIdle(); } int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const { return transport_->GetPeerAddress(address); } int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const { return transport_->GetLocalAddress(address); } const NetLogWithSource& WrappedStreamSocket::NetLog() const { return transport_->NetLog(); } bool WrappedStreamSocket::WasEverUsed() const { return transport_->WasEverUsed(); } NextProto WrappedStreamSocket::GetNegotiatedProtocol() const { return transport_->GetNegotiatedProtocol(); } bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) { return transport_->GetSSLInfo(ssl_info); } int64_t WrappedStreamSocket::GetTotalReceivedBytes() const { return transport_->GetTotalReceivedBytes(); } void 
WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) { transport_->ApplySocketTag(tag); } int WrappedStreamSocket::Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { return transport_->Read(buf, buf_len, std::move(callback)); } int WrappedStreamSocket::ReadIfReady(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) { return transport_->ReadIfReady(buf, buf_len, std::move((callback))); } int WrappedStreamSocket::Write( IOBuffer* buf, int buf_len, CompletionOnceCallback callback, const NetworkTrafficAnnotationTag& traffic_annotation) { return transport_->Write(buf, buf_len, std::move(callback), TRAFFIC_ANNOTATION_FOR_TESTS); } int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) { return transport_->SetReceiveBufferSize(size); } int WrappedStreamSocket::SetSendBufferSize(int32_t size) { return transport_->SetSendBufferSize(size); } int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) { connected_ = true; return WrappedStreamSocket::Connect(std::move(callback)); } void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) { tagged_before_connected_ &= !connected_ || tag == tag_; tag_ = tag; transport_->ApplySocketTag(tag); } std::unique_ptr MockTaggingClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, std::unique_ptr socket_performance_watcher, NetworkQualityEstimator* network_quality_estimator, NetLog* net_log, const NetLogSource& source) { auto socket = std::make_unique( MockClientSocketFactory::CreateTransportClientSocket( addresses, std::move(socket_performance_watcher), network_quality_estimator, net_log, source)); tcp_socket_ = socket.get(); return std::move(socket); } std::unique_ptr MockTaggingClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, NetLog* net_log, const NetLogSource& source) { std::unique_ptr socket( MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log, source)); udp_socket_ = static_cast(socket.get()); 
return socket; } const char kSOCKS4TestHost[] = "127.0.0.1"; const int kSOCKS4TestPort = 80; const char kSOCKS4OkRequestLocalHostPort80[] = {0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0}; const int kSOCKS4OkRequestLocalHostPort80Length = std::size(kSOCKS4OkRequestLocalHostPort80); const char kSOCKS4OkReply[] = {0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0}; const int kSOCKS4OkReplyLength = std::size(kSOCKS4OkReply); const char kSOCKS5TestHost[] = "host"; const int kSOCKS5TestPort = 80; const char kSOCKS5GreetRequest[] = {0x05, 0x01, 0x00}; const int kSOCKS5GreetRequestLength = std::size(kSOCKS5GreetRequest); const char kSOCKS5GreetResponse[] = {0x05, 0x00}; const int kSOCKS5GreetResponseLength = std::size(kSOCKS5GreetResponse); const char kSOCKS5OkRequest[] = {0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50}; const int kSOCKS5OkRequestLength = std::size(kSOCKS5OkRequest); const char kSOCKS5OkResponse[] = {0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50}; const int kSOCKS5OkResponseLength = std::size(kSOCKS5OkResponse); int64_t CountReadBytes(base::span reads) { int64_t total = 0; for (const MockRead& read : reads) total += read.data_len; return total; } int64_t CountWriteBytes(base::span writes) { int64_t total = 0; for (const MockWrite& write : writes) total += write.data_len; return total; } #if BUILDFLAG(IS_ANDROID) bool CanGetTaggedBytes() { // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be // present, and has been replaced with eBPF Traffic Monitoring in netd. See: // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor // // To read traffic statistics from netd, apps should use the API // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide // statistics for local traffic, only mobile and WiFi traffic, so it would not // work in tests that spin up a local server. So for now, GetTaggedBytes is // only supported on Android releases older than P. 
return base::android::BuildInfo::GetInstance()->sdk_int() < base::android::SDK_VERSION_P; } uint64_t GetTaggedBytes(int32_t expected_tag) { EXPECT_TRUE(CanGetTaggedBytes()); // To determine how many bytes the system saw with a particular tag read // the /proc/net/xt_qtaguid/stats file which contains the kernel's // dump of all the UIDs and their tags sent and received bytes. uint64_t bytes = 0; std::string contents; EXPECT_TRUE(base::ReadFileToString( base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents)); for (size_t i = contents.find('\n'); // Skip first line which is headers. i != std::string::npos && i < contents.length();) { uint64_t tag, rx_bytes; uid_t uid; int n; // Parse out the numbers we care about. For reference here's the column // headers: // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes // tx_udp_packets tx_other_bytes tx_other_packets EXPECT_EQ(sscanf(contents.c_str() + i, "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64 " %*d %*d %*d %*d %*d %*d %*d %*d " "%*d %*d %*d %*d %*d %*d %*d%n", &tag, &uid, &rx_bytes, &n), 3); // If this line matches our UID and |expected_tag| then add it to the total. if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) { bytes += rx_bytes; } // Move |i| to the next line. i += n + 1; } return bytes; } #endif } // namespace net