// Copyright 2012 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
#define NET_SOCKET_SOCKET_TEST_UTIL_H_

#include <stddef.h>
#include <stdint.h>

#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "base/check_op.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "build/build_config.h"
#include "net/base/address_list.h"
#include "net/base/completion_once_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/test_completion_callback.h"
#include "net/http/http_auth_controller.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/datagram_client_socket.h"
#include "net/socket/socket_performance_watcher.h"
#include "net/socket/socket_tag.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/transport_client_socket.h"
#include "net/socket/transport_client_socket_pool.h"
#include "net/ssl/ssl_config_service.h"
#include "net/ssl/ssl_info.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {
class RunLoop;
}

namespace net {

struct CommonConnectJobParams;
class NetLog;
struct NetworkTrafficAnnotationTag;
class X509Certificate;

const handles::NetworkHandle kDefaultNetworkForTests = 1;
const handles::NetworkHandle kNewNetworkForTests = 2;

enum {
  // A private network error code used by the socket test utility classes.
  // If the |result| member of a MockRead is
  // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
  // marker that indicates the peer will close the connection after the next
  // MockRead. The other members of that MockRead are ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};

class AsyncSocket;
class MockClientSocket;
class SSLClientSocket;
class StreamSocket;

// Selects whether a mocked operation completes asynchronously or
// synchronously.
enum IoMode { ASYNC, SYNCHRONOUS };

// Describes the outcome of a mocked Connect() call.
struct MockConnect {
  // Asynchronous connection success.
  // Creates a MockConnect with |mode| ASYNC, |result| OK, and
  // |peer_addr| 192.0.2.33.
  MockConnect();
  // Creates a MockConnect with the specified mode and result, with
  // |peer_addr| 192.0.2.33.
  MockConnect(IoMode io_mode, int r);
  // Same as above, but with an explicit peer address.
  MockConnect(IoMode io_mode, int r, IPEndPoint addr);
  // Same as above, additionally setting |first_attempt_fails|.
  MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
  ~MockConnect();

  IoMode mode;
  int result;
  IPEndPoint peer_addr;
  // NOTE(review): presumably makes the first connection attempt fail before a
  // later one yields |result| -- confirm against the socket implementation.
  bool first_attempt_fails = false;
};

// Describes the outcome of a mocked handshake confirmation.
struct MockConfirm {
  // Asynchronous confirm success.
  // Creates a MockConfirm with |mode| ASYNC and |result| OK.
  MockConfirm();
  // Creates a MockConfirm with the specified mode and result.
  MockConfirm(IoMode io_mode, int r);
  ~MockConfirm();

  IoMode mode;
  int result;
};

// MockRead and MockWrite share the same interface and members, but we'd like
// to have distinct types because we don't want to have them used
// interchangeably. To do this, a struct template is defined, and MockRead and
// MockWrite are instantiated by using this template. Template parameter |type|
// is not used in the struct definition (it purely exists for creating a new
// type).
114 // 115 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead 116 // is the data returned from the socket when MockTCPClientSocket::Read() is 117 // attempted, while |data| in MockWrite is the expected data that should be 118 // given in MockTCPClientSocket::Write(). 119 enum MockReadWriteType { MOCK_READ, MOCK_WRITE }; 120 121 template <MockReadWriteType type> 122 struct MockReadWrite { 123 // Flag to indicate that the message loop should be terminated. 124 enum { STOPLOOP = 1 << 31 }; 125 126 // Default MockReadWriteMockReadWrite127 MockReadWrite() 128 : mode(SYNCHRONOUS), 129 result(0), 130 data(nullptr), 131 data_len(0), 132 sequence_number(0), 133 tos(0) {} 134 135 // Read/write failure (no data). MockReadWriteMockReadWrite136 MockReadWrite(IoMode io_mode, int result) 137 : mode(io_mode), 138 result(result), 139 data(nullptr), 140 data_len(0), 141 sequence_number(0), 142 tos(0) {} 143 144 // Read/write failure (no data), with sequence information. MockReadWriteMockReadWrite145 MockReadWrite(IoMode io_mode, int result, int seq) 146 : mode(io_mode), 147 result(result), 148 data(nullptr), 149 data_len(0), 150 sequence_number(seq), 151 tos(0) {} 152 153 // Asynchronous read/write success (inferred data length). MockReadWriteMockReadWrite154 explicit MockReadWrite(const char* data) 155 : mode(ASYNC), 156 result(0), 157 data(data), 158 data_len(strlen(data)), 159 sequence_number(0), 160 tos(0) {} 161 162 // Read/write success (inferred data length). MockReadWriteMockReadWrite163 MockReadWrite(IoMode io_mode, const char* data) 164 : mode(io_mode), 165 result(0), 166 data(data), 167 data_len(strlen(data)), 168 sequence_number(0), 169 tos(0) {} 170 171 // Read/write success. 
MockReadWriteMockReadWrite172 MockReadWrite(IoMode io_mode, const char* data, int data_len) 173 : mode(io_mode), 174 result(0), 175 data(data), 176 data_len(data_len), 177 sequence_number(0), 178 tos(0) {} 179 180 // Read/write success (inferred data length) with sequence information. MockReadWriteMockReadWrite181 MockReadWrite(IoMode io_mode, int seq, const char* data) 182 : mode(io_mode), 183 result(0), 184 data(data), 185 data_len(strlen(data)), 186 sequence_number(seq), 187 tos(0) {} 188 189 // Read/write success with sequence information. MockReadWriteMockReadWrite190 MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) 191 : mode(io_mode), 192 result(0), 193 data(data), 194 data_len(data_len), 195 sequence_number(seq), 196 tos(0) {} 197 198 // Read/write success with sequence and TOS information. MockReadWriteMockReadWrite199 MockReadWrite(IoMode io_mode, 200 const char* data, 201 int data_len, 202 int seq, 203 uint8_t tos_byte) 204 : mode(io_mode), 205 result(0), 206 data(data), 207 data_len(data_len), 208 sequence_number(seq), 209 tos(tos_byte) {} 210 211 IoMode mode; 212 int result; 213 const char* data; 214 int data_len; 215 216 // For data providers that only allows reads to occur in a particular 217 // sequence. If a read occurs before the given |sequence_number| is reached, 218 // an ERR_IO_PENDING is returned. 219 int sequence_number; // The sequence number at which a read is allowed 220 // to occur. 221 222 // The TOS byte of the datagram, for datagram sockets only. 
223 uint8_t tos; 224 }; 225 226 typedef MockReadWrite<MOCK_READ> MockRead; 227 typedef MockReadWrite<MOCK_WRITE> MockWrite; 228 229 struct MockWriteResult { MockWriteResultMockWriteResult230 MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {} 231 232 IoMode mode; 233 int result; 234 }; 235 236 class SocketDataPrinter { 237 public: 238 ~SocketDataPrinter() = default; 239 240 // Prints the write in |data| using some sort of protocol-specific 241 // format. 242 virtual std::string PrintWrite(const std::string& data) = 0; 243 }; 244 245 // The SocketDataProvider is an interface used by the MockClientSocket 246 // for getting data about individual reads and writes on the socket. Can be 247 // used with at most one socket at a time. 248 // TODO(mmenke): Do these really need to be re-useable? 249 class SocketDataProvider { 250 public: 251 SocketDataProvider(); 252 253 SocketDataProvider(const SocketDataProvider&) = delete; 254 SocketDataProvider& operator=(const SocketDataProvider&) = delete; 255 256 virtual ~SocketDataProvider(); 257 258 // Returns the buffer and result code for the next simulated read. 259 // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller 260 // that it will be called via the AsyncSocket::OnReadComplete() 261 // function at a later time. 262 virtual MockRead OnRead() = 0; 263 virtual MockWriteResult OnWrite(const std::string& data) = 0; 264 virtual bool AllReadDataConsumed() const = 0; 265 virtual bool AllWriteDataConsumed() const = 0; CancelPendingRead()266 virtual void CancelPendingRead() {} 267 268 // Returns the last set receive buffer size, or -1 if never set. receive_buffer_size()269 int receive_buffer_size() const { return receive_buffer_size_; } set_receive_buffer_size(int receive_buffer_size)270 void set_receive_buffer_size(int receive_buffer_size) { 271 receive_buffer_size_ = receive_buffer_size; 272 } 273 274 // Returns the last set send buffer size, or -1 if never set. 
send_buffer_size()275 int send_buffer_size() const { return send_buffer_size_; } set_send_buffer_size(int send_buffer_size)276 void set_send_buffer_size(int send_buffer_size) { 277 send_buffer_size_ = send_buffer_size; 278 } 279 280 // Returns the last set value of TCP no delay, or false if never set. no_delay()281 bool no_delay() const { return no_delay_; } set_no_delay(bool no_delay)282 void set_no_delay(bool no_delay) { no_delay_ = no_delay; } 283 284 // Returns whether TCP keepalives were enabled or not. Returns kDefault by 285 // default. 286 enum class KeepAliveState { kEnabled, kDisabled, kDefault }; keep_alive_state()287 KeepAliveState keep_alive_state() const { return keep_alive_state_; } 288 // Last set TCP keepalive delay. keep_alive_delay()289 int keep_alive_delay() const { return keep_alive_delay_; } set_keep_alive(bool enable,int delay)290 void set_keep_alive(bool enable, int delay) { 291 keep_alive_state_ = 292 enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled; 293 keep_alive_delay_ = delay; 294 } 295 296 // Setters / getters for the return values of the corresponding Set*() 297 // methods. By default, they all succeed, if the socket is connected. 
298 set_set_receive_buffer_size_result(int receive_buffer_size_result)299 void set_set_receive_buffer_size_result(int receive_buffer_size_result) { 300 set_receive_buffer_size_result_ = receive_buffer_size_result; 301 } set_receive_buffer_size_result()302 int set_receive_buffer_size_result() const { 303 return set_receive_buffer_size_result_; 304 } 305 set_set_send_buffer_size_result(int set_send_buffer_size_result)306 void set_set_send_buffer_size_result(int set_send_buffer_size_result) { 307 set_send_buffer_size_result_ = set_send_buffer_size_result; 308 } set_send_buffer_size_result()309 int set_send_buffer_size_result() const { 310 return set_send_buffer_size_result_; 311 } 312 set_set_no_delay_result(bool set_no_delay_result)313 void set_set_no_delay_result(bool set_no_delay_result) { 314 set_no_delay_result_ = set_no_delay_result; 315 } set_no_delay_result()316 bool set_no_delay_result() const { return set_no_delay_result_; } 317 set_set_keep_alive_result(bool set_keep_alive_result)318 void set_set_keep_alive_result(bool set_keep_alive_result) { 319 set_keep_alive_result_ = set_keep_alive_result; 320 } set_keep_alive_result()321 bool set_keep_alive_result() const { return set_keep_alive_result_; } 322 expected_addresses()323 const std::optional<AddressList>& expected_addresses() const { 324 return expected_addresses_; 325 } set_expected_addresses(net::AddressList addresses)326 void set_expected_addresses(net::AddressList addresses) { 327 expected_addresses_ = std::move(addresses); 328 } 329 330 // Returns true if the request should be considered idle, for the purposes of 331 // IsConnectedAndIdle. 332 virtual bool IsIdle() const; 333 334 // Initializes the SocketDataProvider for use with |socket|. Must be called 335 // before use 336 void Initialize(AsyncSocket* socket); 337 // Detaches the socket associated with a SocketDataProvider. 
Must be called 338 // before |socket_| is destroyed, unless the SocketDataProvider has informed 339 // |socket_| it was destroyed. Must also be called before Initialize() may 340 // be called again with a new socket. 341 void DetachSocket(); 342 343 // Accessor for the socket which is using the SocketDataProvider. socket()344 AsyncSocket* socket() { return socket_; } 345 connect_data()346 MockConnect connect_data() const { return connect_; } set_connect_data(const MockConnect & connect)347 void set_connect_data(const MockConnect& connect) { connect_ = connect; } 348 349 private: 350 // Called to inform subclasses of initialization. 351 virtual void Reset() = 0; 352 353 MockConnect connect_; 354 raw_ptr<AsyncSocket> socket_ = nullptr; 355 356 int receive_buffer_size_ = -1; 357 int send_buffer_size_ = -1; 358 // This reflects the default state of TCPClientSockets. 359 bool no_delay_ = true; 360 361 KeepAliveState keep_alive_state_ = KeepAliveState::kDefault; 362 int keep_alive_delay_ = 0; 363 364 int set_receive_buffer_size_result_ = net::OK; 365 int set_send_buffer_size_result_ = net::OK; 366 bool set_no_delay_result_ = true; 367 bool set_keep_alive_result_ = true; 368 std::optional<AddressList> expected_addresses_; 369 }; 370 371 // The AsyncSocket is an interface used by the SocketDataProvider to 372 // complete the asynchronous read operation. 373 class AsyncSocket { 374 public: 375 // If an async IO is pending because the SocketDataProvider returned 376 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 377 // is called to complete the asynchronous read operation. 378 // data.async is ignored, and this read is completed synchronously as 379 // part of this call. 380 // TODO(rch): this should take a StringPiece since most of the fields 381 // are ignored. 
382 virtual void OnReadComplete(const MockRead& data) = 0; 383 // If an async IO is pending because the SocketDataProvider returned 384 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 385 // is called to complete the asynchronous read operation. 386 virtual void OnWriteComplete(int rv) = 0; 387 virtual void OnConnectComplete(const MockConnect& data) = 0; 388 389 // Called when the SocketDataProvider associated with the socket is destroyed. 390 // The socket may continue to be used after the data provider is destroyed, 391 // so it should be sure not to dereference the provider after this is called. 392 virtual void OnDataProviderDestroyed() = 0; 393 }; 394 395 // StaticSocketDataHelper manages a list of reads and writes. 396 class StaticSocketDataHelper { 397 public: 398 StaticSocketDataHelper(base::span<const MockRead> reads, 399 base::span<const MockWrite> writes); 400 401 StaticSocketDataHelper(const StaticSocketDataHelper&) = delete; 402 StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete; 403 404 ~StaticSocketDataHelper(); 405 406 // These functions get access to the next available read and write data. They 407 // CHECK fail if there is no data available. 408 const MockRead& PeekRead() const; 409 const MockWrite& PeekWrite() const; 410 411 // Returns the current read or write, and then advances to the next one. 412 const MockRead& AdvanceRead(); 413 const MockWrite& AdvanceWrite(); 414 415 // Resets the read and write indexes to 0. 416 void Reset(); 417 418 // Returns true if |data| is valid data for the next write. In order 419 // to support short writes, the next write may be longer than |data| 420 // in which case this method will still return true. 
421 bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer); 422 read_index()423 size_t read_index() const { return read_index_; } write_index()424 size_t write_index() const { return write_index_; } read_count()425 size_t read_count() const { return reads_.size(); } write_count()426 size_t write_count() const { return writes_.size(); } 427 AllReadDataConsumed()428 bool AllReadDataConsumed() const { return read_index() >= read_count(); } AllWriteDataConsumed()429 bool AllWriteDataConsumed() const { return write_index() >= write_count(); } 430 431 void ExpectAllReadDataConsumed(SocketDataPrinter* printer) const; 432 void ExpectAllWriteDataConsumed(SocketDataPrinter* printer) const; 433 434 private: 435 // Returns the next available read or write that is not a pause event. CHECK 436 // fails if no data is available. 437 const MockWrite& PeekRealWrite() const; 438 439 const base::span<const MockRead> reads_; 440 size_t read_index_ = 0; 441 const base::span<const MockWrite> writes_; 442 size_t write_index_ = 0; 443 }; 444 445 // SocketDataProvider which responds based on static tables of mock reads and 446 // writes. 447 class StaticSocketDataProvider : public SocketDataProvider { 448 public: 449 StaticSocketDataProvider(); 450 StaticSocketDataProvider(base::span<const MockRead> reads, 451 base::span<const MockWrite> writes); 452 453 StaticSocketDataProvider(const StaticSocketDataProvider&) = delete; 454 StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete; 455 456 ~StaticSocketDataProvider() override; 457 458 // Pause/resume reads from this provider. 
459 void Pause(); 460 void Resume(); 461 462 // From SocketDataProvider: 463 MockRead OnRead() override; 464 MockWriteResult OnWrite(const std::string& data) override; 465 bool AllReadDataConsumed() const override; 466 bool AllWriteDataConsumed() const override; 467 read_index()468 size_t read_index() const { return helper_.read_index(); } write_index()469 size_t write_index() const { return helper_.write_index(); } read_count()470 size_t read_count() const { return helper_.read_count(); } write_count()471 size_t write_count() const { return helper_.write_count(); } 472 set_printer(SocketDataPrinter * printer)473 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 474 475 private: 476 // From SocketDataProvider: 477 void Reset() override; 478 479 StaticSocketDataHelper helper_; 480 raw_ptr<SocketDataPrinter> printer_ = nullptr; 481 bool paused_ = false; 482 }; 483 484 // SSLSocketDataProviders only need to keep track of the return code from calls 485 // to Connect(). 486 struct SSLSocketDataProvider { 487 SSLSocketDataProvider(IoMode mode, int result); 488 SSLSocketDataProvider(const SSLSocketDataProvider& other); 489 ~SSLSocketDataProvider(); 490 491 // Returns whether MockConnect data has been consumed. ConnectDataConsumedSSLSocketDataProvider492 bool ConnectDataConsumed() const { return is_connect_data_consumed; } 493 494 // Returns whether MockConfirm data has been consumed. ConfirmDataConsumedSSLSocketDataProvider495 bool ConfirmDataConsumed() const { return is_confirm_data_consumed; } 496 497 // Returns whether a Write occurred before ConfirmHandshake completed. WriteBeforeConfirmSSLSocketDataProvider498 bool WriteBeforeConfirm() const { return write_called_before_confirm; } 499 500 // Result for Connect(). 501 MockConnect connect; 502 // Callback to run when Connect() is called. This is called at most once per 503 // socket but is repeating because SSLSocketDataProvider is copyable. 
504 base::RepeatingClosure connect_callback; 505 506 // Result for ConfirmHandshake(). 507 MockConfirm confirm; 508 // Callback to run when ConfirmHandshake() is called. This is called at most 509 // once per socket but is repeating because SSLSocketDataProvider is 510 // copyable. 511 base::RepeatingClosure confirm_callback; 512 513 // Result for GetNegotiatedProtocol(). 514 NextProto next_proto = kProtoUnknown; 515 516 // Result for GetPeerApplicationSettings(). 517 std::optional<std::string> peer_application_settings; 518 519 // Result for GetSSLInfo(). 520 SSLInfo ssl_info; 521 522 // Result for GetSSLCertRequestInfo(). 523 scoped_refptr<SSLCertRequestInfo> cert_request_info; 524 525 // Result for GetECHRetryConfigs(). 526 std::vector<uint8_t> ech_retry_configs; 527 528 std::optional<NextProtoVector> next_protos_expected_in_ssl_config; 529 std::optional<SSLConfig::ApplicationSettings> expected_application_settings; 530 531 uint16_t expected_ssl_version_min; 532 uint16_t expected_ssl_version_max; 533 std::optional<bool> expected_early_data_enabled; 534 std::optional<bool> expected_send_client_cert; 535 scoped_refptr<X509Certificate> expected_client_cert; 536 std::optional<HostPortPair> expected_host_and_port; 537 std::optional<bool> expected_ignore_certificate_errors; 538 std::optional<NetworkAnonymizationKey> expected_network_anonymization_key; 539 std::optional<std::vector<uint8_t>> expected_ech_config_list; 540 541 bool is_connect_data_consumed = false; 542 bool is_confirm_data_consumed = false; 543 bool write_called_before_confirm = false; 544 }; 545 546 // Uses the sequence_number field in the mock reads and writes to 547 // complete the operations in a specified order. 548 class SequencedSocketData : public SocketDataProvider { 549 public: 550 SequencedSocketData(); 551 552 // |reads| is the list of MockRead completions. 553 // |writes| is the list of MockWrite completions. 
554 SequencedSocketData(base::span<const MockRead> reads, 555 base::span<const MockWrite> writes); 556 557 // |connect| is the result for the connect phase. 558 // |reads| is the list of MockRead completions. 559 // |writes| is the list of MockWrite completions. 560 SequencedSocketData(const MockConnect& connect, 561 base::span<const MockRead> reads, 562 base::span<const MockWrite> writes); 563 564 SequencedSocketData(const SequencedSocketData&) = delete; 565 SequencedSocketData& operator=(const SequencedSocketData&) = delete; 566 567 ~SequencedSocketData() override; 568 569 // From SocketDataProvider: 570 MockRead OnRead() override; 571 MockWriteResult OnWrite(const std::string& data) override; 572 bool AllReadDataConsumed() const override; 573 bool AllWriteDataConsumed() const override; 574 bool IsIdle() const override; 575 void CancelPendingRead() override; 576 577 // EXPECTs that all data has been consumed, printing any un-consumed data. 578 void ExpectAllReadDataConsumed() const; 579 void ExpectAllWriteDataConsumed() const; 580 581 // An ASYNC read event with a return value of ERR_IO_PENDING will cause the 582 // socket data to pause at that event, and advance no further, until Resume is 583 // invoked. At that point, the socket will continue at the next event in the 584 // sequence. 585 // 586 // If a request just wants to simulate a connection that stays open and never 587 // receives any more data, instead of pausing and then resuming a request, it 588 // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING 589 // instead. 590 bool IsPaused() const; 591 // Resumes events once |this| is in the paused state. The next event will 592 // occur synchronously with the call if it can. 593 void Resume(); 594 void RunUntilPaused(); 595 596 // When true, IsConnectedAndIdle() will return false if the next event in the 597 // sequence is a synchronous. Otherwise, the socket claims to be idle as 598 // long as it's connected. Defaults to false. 
599 // TODO(mmenke): See if this can be made the default behavior, and consider 600 // removing this mehtod. Need to make sure it doesn't change what code any 601 // tests are targetted at testing. set_busy_before_sync_reads(bool busy_before_sync_reads)602 void set_busy_before_sync_reads(bool busy_before_sync_reads) { 603 busy_before_sync_reads_ = busy_before_sync_reads; 604 } 605 set_printer(SocketDataPrinter * printer)606 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 607 608 private: 609 // Defines the state for the read or write path. 610 enum class IoState { 611 kIdle, // No async operation is in progress. 612 kPending, // An async operation in waiting for another operation to 613 // complete. 614 kCompleting, // A task has been posted to complete an async operation. 615 kPaused, // IO is paused until Resume() is called. 616 }; 617 618 // From SocketDataProvider: 619 void Reset() override; 620 621 void OnReadComplete(); 622 void OnWriteComplete(); 623 624 void MaybePostReadCompleteTask(); 625 void MaybePostWriteCompleteTask(); 626 627 StaticSocketDataHelper helper_; 628 raw_ptr<SocketDataPrinter> printer_ = nullptr; 629 int sequence_number_ = 0; 630 IoState read_state_ = IoState::kIdle; 631 IoState write_state_ = IoState::kIdle; 632 633 bool busy_before_sync_reads_ = false; 634 635 // Used by RunUntilPaused. NULL at all other times. 636 std::unique_ptr<base::RunLoop> run_until_paused_run_loop_; 637 638 base::WeakPtrFactory<SequencedSocketData> weak_factory_{this}; 639 }; 640 641 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket 642 // objects get instantiated, they take their data from the i'th element of this 643 // array. 
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next element and advances the cursor. CHECK-fails if every
  // element has already been handed out, so that an exhausted array is a
  // deterministic crash rather than an out-of-bounds vector access in builds
  // where DCHECKs are compiled out.
  T* GetNext() {
    CHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Like GetNext(), but returns nullptr when the end of the array is reached,
  // instead of crashing. GetNext() should generally be preferred, unless
  // having no remaining elements is expected in some cases and is handled
  // safely.
  T* GetNextWithoutAsserting() {
    if (next_index_ >= data_providers_.size()) {
      return nullptr;
    }
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-null. The pointer is stored
  // as-is; this class never deletes it.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext*() call will return.
  size_t next_index() const { return next_index_; }

  // Rewinds the cursor so the elements are handed out again from the start.
  void ResetNextIndex() { next_index_ = 0; }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_ = 0;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};

class MockUDPClientSocket;
class MockTCPClientSocket;
class MockSSLClientSocket;

// ClientSocketFactory which contains arrays of sockets of each type.
// You should first fill the arrays using Add{SSL,}SocketDataProvider(). When
// the factory is asked to create a socket, it takes next entry from appropriate
// array. You can use ResetNextMockIndexes to reset that next entry index for
// all mock socket types.
class MockClientSocketFactory : public ClientSocketFactory {
 public:
  MockClientSocketFactory();

  MockClientSocketFactory(const MockClientSocketFactory&) = delete;
  MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;

  ~MockClientSocketFactory() override;

  // Adds a SocketDataProvider that can be used to serve either TCP or UDP
  // connection requests. Sockets are returned in FIFO order.
  void AddSocketDataProvider(SocketDataProvider* socket);

  // Like AddSocketDataProvider(), except sockets will only be used to service
  // TCP connection requests. Sockets added with this method are used first,
  // before sockets added with AddSocketDataProvider(). Particularly useful for
  // QUIC tests with multiple sockets, where TCP connections may or may not be
  // made, and have no guaranteed order, relative to UDP connections.
  void AddTcpSocketDataProvider(SocketDataProvider* socket);

  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
  void ResetNextMockIndexes();

  // Direct access to the array filled by AddSocketDataProvider().
  SocketDataProviderArray<SocketDataProvider>& mock_data() {
    return mock_data_;
  }

  // See |enable_read_if_ready_| below.
  void set_enable_read_if_ready(bool enable_read_if_ready) {
    enable_read_if_ready_ = enable_read_if_ready;
  }

  // ClientSocketFactory
  std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
      DatagramSocket::BindType bind_type,
      NetLog* net_log,
      const NetLogSource& source) override;
  std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
      const AddressList& addresses,
      std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
      NetworkQualityEstimator* network_quality_estimator,
      NetLog* net_log,
      const NetLogSource& source) override;
  std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
      SSLClientContext* context,
      std::unique_ptr<StreamSocket> stream_socket,
      const HostPortPair& host_and_port,
      const SSLConfig& ssl_config) override;
  const std::vector<uint16_t>& udp_client_socket_ports() const {
    return udp_client_socket_ports_;
  }

 private:
  SocketDataProviderArray<SocketDataProvider> mock_data_;
  SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
  std::vector<uint16_t> udp_client_socket_ports_;

  // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
  // ERR_READ_IF_READY_NOT_IMPLEMENTED.
  bool enable_read_if_ready_ = false;
};

// Common base for the mock stream sockets. Read(), Write() and Connect() are
// left pure virtual for subclasses to implement.
class MockClientSocket : public TransportClientSocket {
 public:
  // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog
  // IDs as unique socket IDs.
  explicit MockClientSocket(const NetLogWithSource& net_log);

  MockClientSocket(const MockClientSocket&) = delete;
  MockClientSocket& operator=(const MockClientSocket&) = delete;

  // Socket implementation.
  int Read(IOBuffer* buf,
           int buf_len,
           CompletionOnceCallback callback) override = 0;
  int Write(IOBuffer* buf,
            int buf_len,
            CompletionOnceCallback callback,
            const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
  int SetReceiveBufferSize(int32_t size) override;
  int SetSendBufferSize(int32_t size) override;

  // TransportClientSocket implementation.
  int Bind(const net::IPEndPoint& local_addr) override;
  bool SetNoDelay(bool no_delay) override;
  bool SetKeepAlive(bool enable, int delay) override;

  // StreamSocket implementation.
  int Connect(CompletionOnceCallback callback) override = 0;
  void Disconnect() override;
  bool IsConnected() const override;
  bool IsConnectedAndIdle() const override;
  int GetPeerAddress(IPEndPoint* address) const override;
  int GetLocalAddress(IPEndPoint* address) const override;
  const NetLogWithSource& NetLog() const override;
  NextProto GetNegotiatedProtocol() const override;
  int64_t GetTotalReceivedBytes() const override;
  // Intentionally a no-op.
  void ApplySocketTag(const SocketTag& tag) override {}

 protected:
  ~MockClientSocket() override;
  void RunCallbackAsync(CompletionOnceCallback callback, int result);
  void RunCallback(CompletionOnceCallback callback, int result);

  // True if Connect completed successfully and Disconnect hasn't been called.
  bool connected_ = false;

  IPEndPoint local_addr_;
  IPEndPoint peer_addr_;

  NetLogWithSource net_log_;

 private:
  base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
};

// Mock TCP socket. It takes a SocketDataProvider at construction and also
// implements AsyncSocket, so the provider can complete operations on it.
class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
 public:
  MockTCPClientSocket(const AddressList& addresses,
                      net::NetLog* net_log,
                      SocketDataProvider* socket);

  MockTCPClientSocket(const MockTCPClientSocket&) = delete;
  MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;

  ~MockTCPClientSocket() override;

  const AddressList& addresses() const { return addresses_; }

  // Socket implementation.
  int Read(IOBuffer* buf,
           int buf_len,
           CompletionOnceCallback callback) override;
  int ReadIfReady(IOBuffer* buf,
                  int buf_len,
                  CompletionOnceCallback callback) override;
  int CancelReadIfReady() override;
  int Write(IOBuffer* buf,
            int buf_len,
            CompletionOnceCallback callback,
            const NetworkTrafficAnnotationTag& traffic_annotation) override;
  int SetReceiveBufferSize(int32_t size) override;
  int SetSendBufferSize(int32_t size) override;

  // TransportClientSocket implementation.
  bool SetNoDelay(bool no_delay) override;
  bool SetKeepAlive(bool enable, int delay) override;

  // StreamSocket implementation.
  void SetBeforeConnectCallback(
      const BeforeConnectCallback& before_connect_callback) override;
  int Connect(CompletionOnceCallback callback) override;
  void Disconnect() override;
  bool IsConnected() const override;
  bool IsConnectedAndIdle() const override;
  int GetPeerAddress(IPEndPoint* address) const override;
  bool WasEverUsed() const override;
  bool GetSSLInfo(SSLInfo* ssl_info) override;

  // AsyncSocket:
  void OnReadComplete(const MockRead& data) override;
  void OnWriteComplete(int rv) override;
  void OnConnectComplete(const MockConnect& data) override;
  void OnDataProviderDestroyed() override;

  // See |enable_read_if_ready_| below.
  void set_enable_read_if_ready(bool enable_read_if_ready) {
    enable_read_if_ready_ = enable_read_if_ready;
  }

 private:
  void RetryRead(int rv);
  int ReadIfReadyImpl(IOBuffer* buf,
                      int buf_len,
                      CompletionOnceCallback callback);

  // Helper method to run |pending_read_if_ready_callback_| if it is not null.
  void RunReadIfReadyCallback(int result);

  AddressList addresses_;

  raw_ptr<SocketDataProvider> data_;
  int read_offset_ = 0;
  MockRead read_data_;
  bool need_read_data_ = true;

  // True if the peer has closed the connection.  This allows us to simulate
  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
  // TCPClientSocket.
  bool peer_closed_connection_ = false;

  // While an asynchronous read is pending, we save our user-buffer state.
  scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
  int pending_read_buf_len_ = 0;
  CompletionOnceCallback pending_read_callback_;

  // Non-null when a ReadIfReady() is pending.
  CompletionOnceCallback pending_read_if_ready_callback_;

  CompletionOnceCallback pending_connect_callback_;
  CompletionOnceCallback pending_write_callback_;
  bool was_used_to_convey_data_ = false;

  // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
  // ERR_READ_IF_READY_NOT_IMPLEMENTED.
  bool enable_read_if_ready_ = false;

  BeforeConnectCallback before_connect_callback_;
};

// Mock SSL socket. It is constructed from an underlying StreamSocket and an
// SSLSocketDataProvider.
class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
 public:
  MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
                      const HostPortPair& host_and_port,
                      const SSLConfig& ssl_config,
                      SSLSocketDataProvider* socket);

  MockSSLClientSocket(const MockSSLClientSocket&) = delete;
  MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;

  ~MockSSLClientSocket() override;

  // Socket implementation.
  int Read(IOBuffer* buf,
           int buf_len,
           CompletionOnceCallback callback) override;
  int ReadIfReady(IOBuffer* buf,
                  int buf_len,
                  CompletionOnceCallback callback) override;
  int Write(IOBuffer* buf,
            int buf_len,
            CompletionOnceCallback callback,
            const NetworkTrafficAnnotationTag& traffic_annotation) override;
  int CancelReadIfReady() override;

  // StreamSocket implementation.
  int Connect(CompletionOnceCallback callback) override;
  void Disconnect() override;
  int ConfirmHandshake(CompletionOnceCallback callback) override;
  bool IsConnected() const override;
  bool IsConnectedAndIdle() const override;
  bool WasEverUsed() const override;
  int GetPeerAddress(IPEndPoint* address) const override;
  int GetLocalAddress(IPEndPoint* address) const override;
  NextProto GetNegotiatedProtocol() const override;
  std::optional<std::string_view> GetPeerApplicationSettings() const override;
  bool GetSSLInfo(SSLInfo* ssl_info) override;
  void GetSSLCertRequestInfo(
      SSLCertRequestInfo* cert_request_info) const override;
  void ApplySocketTag(const SocketTag& tag) override;
  const NetLogWithSource& NetLog() const override;
  int64_t GetTotalReceivedBytes() const override;
  int SetReceiveBufferSize(int32_t size) override;
  int SetSendBufferSize(int32_t size) override;

  // SSLSocket implementation.
  int ExportKeyingMaterial(std::string_view label,
                           bool has_context,
                           std::string_view context,
                           unsigned char* out,
                           unsigned int outlen) override;

  // SSLClientSocket implementation.
  std::vector<uint8_t> GetECHRetryConfigs() override;

  // This MockSocket does not implement the manual async IO feature.
  void OnReadComplete(const MockRead& data) override;
  void OnWriteComplete(int rv) override;
  void OnConnectComplete(const MockConnect& data) override;
  // SSL sockets don't need magic to deal with destruction of their data
  // provider.
  // TODO(mmenke):  Probably a good idea to support it, anyways.
OnDataProviderDestroyed()963 void OnDataProviderDestroyed() override {} 964 965 private: 966 static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, 967 CompletionOnceCallback callback, 968 int rv); 969 970 void RunCallbackAsync(CompletionOnceCallback callback, int result); 971 void RunCallback(CompletionOnceCallback callback, int result); 972 973 void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result); 974 975 bool connected_ = false; 976 bool in_confirm_handshake_ = false; 977 NetLogWithSource net_log_; 978 std::unique_ptr<StreamSocket> stream_socket_; 979 raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_; 980 // Address of the "remote" peer we're connected to. 981 IPEndPoint peer_addr_; 982 983 base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this}; 984 }; 985 986 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { 987 public: 988 explicit MockUDPClientSocket(SocketDataProvider* data = nullptr, 989 net::NetLog* net_log = nullptr); 990 991 MockUDPClientSocket(const MockUDPClientSocket&) = delete; 992 MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete; 993 994 ~MockUDPClientSocket() override; 995 996 // Socket implementation. 997 int Read(IOBuffer* buf, 998 int buf_len, 999 CompletionOnceCallback callback) override; 1000 int Write(IOBuffer* buf, 1001 int buf_len, 1002 CompletionOnceCallback callback, 1003 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1004 1005 int SetReceiveBufferSize(int32_t size) override; 1006 int SetSendBufferSize(int32_t size) override; 1007 int SetDoNotFragment() override; 1008 int SetRecvTos() override; 1009 int SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) override; 1010 1011 // DatagramSocket implementation. 
1012 void Close() override; 1013 int GetPeerAddress(IPEndPoint* address) const override; 1014 int GetLocalAddress(IPEndPoint* address) const override; 1015 void UseNonBlockingIO() override; 1016 int SetMulticastInterface(uint32_t interface_index) override; 1017 const NetLogWithSource& NetLog() const override; 1018 1019 // DatagramClientSocket implementation. 1020 int Connect(const IPEndPoint& address) override; 1021 int ConnectUsingNetwork(handles::NetworkHandle network, 1022 const IPEndPoint& address) override; 1023 int ConnectUsingDefaultNetwork(const IPEndPoint& address) override; 1024 int ConnectAsync(const IPEndPoint& address, 1025 CompletionOnceCallback callback) override; 1026 int ConnectUsingNetworkAsync(handles::NetworkHandle network, 1027 const IPEndPoint& address, 1028 CompletionOnceCallback callback) override; 1029 int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address, 1030 CompletionOnceCallback callback) override; 1031 handles::NetworkHandle GetBoundNetwork() const override; 1032 void ApplySocketTag(const SocketTag& tag) override; SetMsgConfirm(bool confirm)1033 void SetMsgConfirm(bool confirm) override {} 1034 DscpAndEcn GetLastTos() const override; 1035 1036 // AsyncSocket implementation. 1037 void OnReadComplete(const MockRead& data) override; 1038 void OnWriteComplete(int rv) override; 1039 void OnConnectComplete(const MockConnect& data) override; 1040 void OnDataProviderDestroyed() override; 1041 set_source_port(uint16_t port)1042 void set_source_port(uint16_t port) { source_port_ = port; } source_port()1043 uint16_t source_port() const { return source_port_; } set_source_host(IPAddress addr)1044 void set_source_host(IPAddress addr) { source_host_ = addr; } source_host()1045 IPAddress source_host() const { return source_host_; } 1046 1047 // Returns last tag applied to socket. tag()1048 SocketTag tag() const { return tag_; } 1049 1050 // Returns false if socket's tag was changed after the socket was used for 1051 // data transfer (e.g. 
Read/Write() called), otherwise returns true. tagged_before_data_transferred()1052 bool tagged_before_data_transferred() const { 1053 return tagged_before_data_transferred_; 1054 } 1055 1056 private: 1057 int CompleteRead(); 1058 1059 void RunCallbackAsync(CompletionOnceCallback callback, int result); 1060 void RunCallback(CompletionOnceCallback callback, int result); 1061 1062 bool connected_ = false; 1063 raw_ptr<SocketDataProvider> data_; 1064 int read_offset_ = 0; 1065 MockRead read_data_; 1066 bool need_read_data_ = true; 1067 IPAddress source_host_; 1068 uint16_t source_port_ = 123; // Ephemeral source port. 1069 1070 // Address of the "remote" peer we're connected to. 1071 IPEndPoint peer_addr_; 1072 1073 // Network that the socket is bound to. 1074 handles::NetworkHandle network_ = handles::kInvalidNetworkHandle; 1075 1076 // While an asynchronous IO is pending, we save our user-buffer state. 1077 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 1078 int pending_read_buf_len_ = 0; 1079 CompletionOnceCallback pending_read_callback_; 1080 CompletionOnceCallback pending_write_callback_; 1081 1082 NetLogWithSource net_log_; 1083 1084 DatagramBuffers unwritten_buffers_; 1085 1086 SocketTag tag_; 1087 bool data_transferred_ = false; 1088 bool tagged_before_data_transferred_ = true; 1089 1090 uint8_t last_tos_ = 0; 1091 1092 base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this}; 1093 }; 1094 1095 class TestSocketRequest : public TestCompletionCallbackBase { 1096 public: 1097 TestSocketRequest(std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>* 1098 request_order, 1099 size_t* completion_count); 1100 1101 TestSocketRequest(const TestSocketRequest&) = delete; 1102 TestSocketRequest& operator=(const TestSocketRequest&) = delete; 1103 1104 ~TestSocketRequest() override; 1105 handle()1106 ClientSocketHandle* handle() { return &handle_; } 1107 callback()1108 CompletionOnceCallback callback() { 1109 return 
base::BindOnce(&TestSocketRequest::OnComplete, 1110 base::Unretained(this)); 1111 } 1112 1113 private: 1114 void OnComplete(int result); 1115 1116 ClientSocketHandle handle_; 1117 raw_ptr<std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>> 1118 request_order_; 1119 raw_ptr<size_t> completion_count_; 1120 }; 1121 1122 class ClientSocketPoolTest { 1123 public: 1124 enum KeepAlive { 1125 KEEP_ALIVE, 1126 1127 // A socket will be disconnected in addition to handle being reset. 1128 NO_KEEP_ALIVE, 1129 }; 1130 1131 static const int kIndexOutOfBounds; 1132 static const int kRequestNotFound; 1133 1134 ClientSocketPoolTest(); 1135 1136 ClientSocketPoolTest(const ClientSocketPoolTest&) = delete; 1137 ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete; 1138 1139 ~ClientSocketPoolTest(); 1140 1141 template <typename PoolType> StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1142 int StartRequestUsingPool( 1143 PoolType* socket_pool, 1144 const ClientSocketPool::GroupId& group_id, 1145 RequestPriority priority, 1146 ClientSocketPool::RespectLimits respect_limits, 1147 const scoped_refptr<typename PoolType::SocketParams>& socket_params) { 1148 DCHECK(socket_pool); 1149 TestSocketRequest* request( 1150 new TestSocketRequest(&request_order_, &completion_count_)); 1151 requests_.push_back(base::WrapUnique(request)); 1152 int rv = request->handle()->Init( 1153 group_id, socket_params, std::nullopt /* proxy_annotation_tag */, 1154 priority, SocketTag(), respect_limits, request->callback(), 1155 ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource()); 1156 if (rv != ERR_IO_PENDING) 1157 request_order_.push_back(request); 1158 return rv; 1159 } 1160 1161 // Provided there were n requests started, takes |index| in range 1..n 1162 // and returns order in which that 
request completed, in range 1..n, 1163 // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound 1164 // if that request did not complete (for example was canceled). 1165 int GetOrderOfRequest(size_t index) const; 1166 1167 // Resets first initialized socket handle from |requests_|. If found such 1168 // a handle, returns true. 1169 bool ReleaseOneConnection(KeepAlive keep_alive); 1170 1171 // Releases connections until there is nothing to release. 1172 void ReleaseAllConnections(KeepAlive keep_alive); 1173 1174 // Note that this uses 0-based indices, while GetOrderOfRequest takes and 1175 // returns 1-based indices. request(int i)1176 TestSocketRequest* request(int i) { return requests_[i].get(); } 1177 requests_size()1178 size_t requests_size() const { return requests_.size(); } requests()1179 std::vector<std::unique_ptr<TestSocketRequest>>* requests() { 1180 return &requests_; 1181 } completion_count()1182 size_t completion_count() const { return completion_count_; } 1183 1184 private: 1185 std::vector<std::unique_ptr<TestSocketRequest>> requests_; 1186 std::vector<raw_ptr<TestSocketRequest, VectorExperimental>> request_order_; 1187 size_t completion_count_ = 0; 1188 }; 1189 1190 class MockTransportSocketParams 1191 : public base::RefCounted<MockTransportSocketParams> { 1192 public: 1193 MockTransportSocketParams(const MockTransportSocketParams&) = delete; 1194 MockTransportSocketParams& operator=(const MockTransportSocketParams&) = 1195 delete; 1196 1197 private: 1198 friend class base::RefCounted<MockTransportSocketParams>; 1199 ~MockTransportSocketParams() = default; 1200 }; 1201 1202 class MockTransportClientSocketPool : public TransportClientSocketPool { 1203 public: 1204 class MockConnectJob { 1205 public: 1206 MockConnectJob(std::unique_ptr<StreamSocket> socket, 1207 ClientSocketHandle* handle, 1208 const SocketTag& socket_tag, 1209 CompletionOnceCallback callback, 1210 RequestPriority priority); 1211 1212 MockConnectJob(const 
MockConnectJob&) = delete; 1213 MockConnectJob& operator=(const MockConnectJob&) = delete; 1214 1215 ~MockConnectJob(); 1216 1217 int Connect(); 1218 bool CancelHandle(const ClientSocketHandle* handle); 1219 handle()1220 ClientSocketHandle* handle() const { return handle_; } 1221 priority()1222 RequestPriority priority() const { return priority_; } set_priority(RequestPriority priority)1223 void set_priority(RequestPriority priority) { priority_ = priority; } 1224 1225 private: 1226 void OnConnect(int rv); 1227 1228 std::unique_ptr<StreamSocket> socket_; 1229 raw_ptr<ClientSocketHandle> handle_; 1230 const SocketTag socket_tag_; 1231 CompletionOnceCallback user_callback_; 1232 RequestPriority priority_; 1233 }; 1234 1235 MockTransportClientSocketPool( 1236 int max_sockets, 1237 int max_sockets_per_group, 1238 const CommonConnectJobParams* common_connect_job_params); 1239 1240 MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete; 1241 MockTransportClientSocketPool& operator=( 1242 const MockTransportClientSocketPool&) = delete; 1243 1244 ~MockTransportClientSocketPool() override; 1245 last_request_priority()1246 RequestPriority last_request_priority() const { 1247 return last_request_priority_; 1248 } 1249 requests()1250 const std::vector<std::unique_ptr<MockConnectJob>>& requests() const { 1251 return job_list_; 1252 } 1253 release_count()1254 int release_count() const { return release_count_; } cancel_count()1255 int cancel_count() const { return cancel_count_; } 1256 1257 // TransportClientSocketPool implementation. 
1258 int RequestSocket( 1259 const GroupId& group_id, 1260 scoped_refptr<ClientSocketPool::SocketParams> socket_params, 1261 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 1262 RequestPriority priority, 1263 const SocketTag& socket_tag, 1264 RespectLimits respect_limits, 1265 ClientSocketHandle* handle, 1266 CompletionOnceCallback callback, 1267 const ProxyAuthCallback& on_auth_callback, 1268 const NetLogWithSource& net_log) override; 1269 void SetPriority(const GroupId& group_id, 1270 ClientSocketHandle* handle, 1271 RequestPriority priority) override; 1272 void CancelRequest(const GroupId& group_id, 1273 ClientSocketHandle* handle, 1274 bool cancel_connect_job) override; 1275 void ReleaseSocket(const GroupId& group_id, 1276 std::unique_ptr<StreamSocket> socket, 1277 int64_t generation) override; 1278 1279 private: 1280 raw_ptr<ClientSocketFactory> client_socket_factory_; 1281 std::vector<std::unique_ptr<MockConnectJob>> job_list_; 1282 RequestPriority last_request_priority_ = DEFAULT_PRIORITY; 1283 int release_count_ = 0; 1284 int cancel_count_ = 0; 1285 }; 1286 1287 // WrappedStreamSocket is a base class that wraps an existing StreamSocket, 1288 // forwarding the Socket and StreamSocket interfaces to the underlying 1289 // transport. 1290 // This is to provide a common base class for subclasses to override specific 1291 // StreamSocket methods for testing, while still communicating with a 'real' 1292 // StreamSocket. 
1293 class WrappedStreamSocket : public TransportClientSocket { 1294 public: 1295 explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport); 1296 ~WrappedStreamSocket() override; 1297 1298 // StreamSocket implementation: 1299 int Bind(const net::IPEndPoint& local_addr) override; 1300 int Connect(CompletionOnceCallback callback) override; 1301 void Disconnect() override; 1302 bool IsConnected() const override; 1303 bool IsConnectedAndIdle() const override; 1304 int GetPeerAddress(IPEndPoint* address) const override; 1305 int GetLocalAddress(IPEndPoint* address) const override; 1306 const NetLogWithSource& NetLog() const override; 1307 bool WasEverUsed() const override; 1308 NextProto GetNegotiatedProtocol() const override; 1309 bool GetSSLInfo(SSLInfo* ssl_info) override; 1310 int64_t GetTotalReceivedBytes() const override; 1311 void ApplySocketTag(const SocketTag& tag) override; 1312 1313 // Socket implementation: 1314 int Read(IOBuffer* buf, 1315 int buf_len, 1316 CompletionOnceCallback callback) override; 1317 int ReadIfReady(IOBuffer* buf, 1318 int buf_len, 1319 CompletionOnceCallback callback) override; 1320 int Write(IOBuffer* buf, 1321 int buf_len, 1322 CompletionOnceCallback callback, 1323 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1324 int SetReceiveBufferSize(int32_t size) override; 1325 int SetSendBufferSize(int32_t size) override; 1326 1327 protected: 1328 std::unique_ptr<StreamSocket> transport_; 1329 }; 1330 1331 // StreamSocket that wraps another StreamSocket, but keeps track of any 1332 // SocketTag applied to the socket. 
1333 class MockTaggingStreamSocket : public WrappedStreamSocket { 1334 public: MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1335 explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport) 1336 : WrappedStreamSocket(std::move(transport)) {} 1337 1338 MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete; 1339 MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete; 1340 1341 ~MockTaggingStreamSocket() override = default; 1342 1343 // StreamSocket implementation. 1344 int Connect(CompletionOnceCallback callback) override; 1345 void ApplySocketTag(const SocketTag& tag) override; 1346 1347 // Returns false if socket's tag was changed after the socket was connected, 1348 // otherwise returns true. tagged_before_connected()1349 bool tagged_before_connected() const { return tagged_before_connected_; } 1350 1351 // Returns last tag applied to socket. tag()1352 SocketTag tag() const { return tag_; } 1353 1354 private: 1355 bool connected_ = false; 1356 bool tagged_before_connected_ = true; 1357 SocketTag tag_; 1358 }; 1359 1360 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and 1361 // keep track of last socket produced for test inspection. 1362 class MockTaggingClientSocketFactory : public MockClientSocketFactory { 1363 public: 1364 MockTaggingClientSocketFactory() = default; 1365 1366 MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) = 1367 delete; 1368 MockTaggingClientSocketFactory& operator=( 1369 const MockTaggingClientSocketFactory&) = delete; 1370 1371 // ClientSocketFactory implementation. 
1372 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 1373 DatagramSocket::BindType bind_type, 1374 NetLog* net_log, 1375 const NetLogSource& source) override; 1376 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 1377 const AddressList& addresses, 1378 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 1379 NetworkQualityEstimator* network_quality_estimator, 1380 NetLog* net_log, 1381 const NetLogSource& source) override; 1382 1383 // These methods return pointers to last TCP and UDP sockets produced by this 1384 // factory. NOTE: Socket must still exist, or pointer will be to freed memory. GetLastProducedTCPSocket()1385 MockTaggingStreamSocket* GetLastProducedTCPSocket() const { 1386 return tcp_socket_; 1387 } GetLastProducedUDPSocket()1388 MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; } 1389 1390 private: 1391 raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ = 1392 nullptr; 1393 raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ = 1394 nullptr; 1395 }; 1396 1397 // Host / port used for SOCKS4 test strings. 1398 extern const char kSOCKS4TestHost[]; 1399 extern const int kSOCKS4TestPort; 1400 1401 // Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost 1402 // on port kSOCKS4TestPort, for the request). 1403 extern const char kSOCKS4OkRequestLocalHostPort80[]; 1404 extern const int kSOCKS4OkRequestLocalHostPort80Length; 1405 1406 extern const char kSOCKS4OkReply[]; 1407 extern const int kSOCKS4OkReplyLength; 1408 1409 // Host / port used for SOCKS5 test strings. 1410 extern const char kSOCKS5TestHost[]; 1411 extern const int kSOCKS5TestPort; 1412 1413 // Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost 1414 // on port kSOCKS5TestPort, for the request).. 
1415 extern const char kSOCKS5GreetRequest[]; 1416 extern const int kSOCKS5GreetRequestLength; 1417 1418 extern const char kSOCKS5GreetResponse[]; 1419 extern const int kSOCKS5GreetResponseLength; 1420 1421 extern const char kSOCKS5OkRequest[]; 1422 extern const int kSOCKS5OkRequestLength; 1423 1424 extern const char kSOCKS5OkResponse[]; 1425 extern const int kSOCKS5OkResponseLength; 1426 1427 // Helper function to get the total data size of the MockReads in |reads|. 1428 int64_t CountReadBytes(base::span<const MockRead> reads); 1429 1430 // Helper function to get the total data size of the MockWrites in |writes|. 1431 int64_t CountWriteBytes(base::span<const MockWrite> writes); 1432 1433 #if BUILDFLAG(IS_ANDROID) 1434 // Returns whether the device supports calling GetTaggedBytes(). 1435 bool CanGetTaggedBytes(); 1436 1437 // Query the system to find out how many bytes were received with tag 1438 // |expected_tag| for our UID. Return the count of received bytes. 1439 uint64_t GetTaggedBytes(int32_t expected_tag); 1440 #endif 1441 1442 } // namespace net 1443 1444 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ 1445