xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/test_tools/quic_test_client.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "quiche/quic/test_tools/quic_test_client.h"
6 
7 #include <memory>
8 #include <utility>
9 #include <vector>
10 
11 #include "absl/strings/match.h"
12 #include "absl/strings/string_view.h"
13 #include "openssl/x509.h"
14 #include "quiche/quic/core/crypto/proof_verifier.h"
15 #include "quiche/quic/core/http/quic_spdy_client_stream.h"
16 #include "quiche/quic/core/http/spdy_utils.h"
17 #include "quiche/quic/core/io/quic_default_event_loop.h"
18 #include "quiche/quic/core/quic_default_clock.h"
19 #include "quiche/quic/core/quic_packet_writer_wrapper.h"
20 #include "quiche/quic/core/quic_server_id.h"
21 #include "quiche/quic/core/quic_stream_priority.h"
22 #include "quiche/quic/core/quic_utils.h"
23 #include "quiche/quic/platform/api/quic_flags.h"
24 #include "quiche/quic/platform/api/quic_logging.h"
25 #include "quiche/quic/platform/api/quic_stack_trace.h"
26 #include "quiche/quic/test_tools/crypto_test_utils.h"
27 #include "quiche/quic/test_tools/quic_connection_peer.h"
28 #include "quiche/quic/test_tools/quic_spdy_session_peer.h"
29 #include "quiche/quic/test_tools/quic_spdy_stream_peer.h"
30 #include "quiche/quic/test_tools/quic_test_utils.h"
31 #include "quiche/quic/tools/quic_url.h"
32 #include "quiche/common/quiche_callbacks.h"
33 #include "quiche/common/quiche_text_utils.h"
34 
35 namespace quic {
36 namespace test {
37 namespace {
38 
39 // RecordingProofVerifier accepts any certificate chain and records the common
40 // name of the leaf and then delegates the actual verification to an actual
41 // verifier. If no optional verifier is provided, then VerifyProof will return
42 // success.
43 class RecordingProofVerifier : public ProofVerifier {
44  public:
RecordingProofVerifier(std::unique_ptr<ProofVerifier> verifier)45   explicit RecordingProofVerifier(std::unique_ptr<ProofVerifier> verifier)
46       : verifier_(std::move(verifier)) {}
47 
48   // ProofVerifier interface.
VerifyProof(const std::string & hostname,const uint16_t port,const std::string & server_config,QuicTransportVersion transport_version,absl::string_view chlo_hash,const std::vector<std::string> & certs,const std::string & cert_sct,const std::string & signature,const ProofVerifyContext * context,std::string * error_details,std::unique_ptr<ProofVerifyDetails> * details,std::unique_ptr<ProofVerifierCallback> callback)49   QuicAsyncStatus VerifyProof(
50       const std::string& hostname, const uint16_t port,
51       const std::string& server_config, QuicTransportVersion transport_version,
52       absl::string_view chlo_hash, const std::vector<std::string>& certs,
53       const std::string& cert_sct, const std::string& signature,
54       const ProofVerifyContext* context, std::string* error_details,
55       std::unique_ptr<ProofVerifyDetails>* details,
56       std::unique_ptr<ProofVerifierCallback> callback) override {
57     QuicAsyncStatus status = ProcessCerts(certs, cert_sct);
58     if (verifier_ == nullptr) {
59       return status;
60     }
61     return verifier_->VerifyProof(hostname, port, server_config,
62                                   transport_version, chlo_hash, certs, cert_sct,
63                                   signature, context, error_details, details,
64                                   std::move(callback));
65   }
66 
VerifyCertChain(const std::string & hostname,const uint16_t port,const std::vector<std::string> & certs,const std::string & ocsp_response,const std::string & cert_sct,const ProofVerifyContext * context,std::string * error_details,std::unique_ptr<ProofVerifyDetails> * details,uint8_t * out_alert,std::unique_ptr<ProofVerifierCallback> callback)67   QuicAsyncStatus VerifyCertChain(
68       const std::string& hostname, const uint16_t port,
69       const std::vector<std::string>& certs, const std::string& ocsp_response,
70       const std::string& cert_sct, const ProofVerifyContext* context,
71       std::string* error_details, std::unique_ptr<ProofVerifyDetails>* details,
72       uint8_t* out_alert,
73       std::unique_ptr<ProofVerifierCallback> callback) override {
74     // Record the cert.
75     QuicAsyncStatus status = ProcessCerts(certs, cert_sct);
76     if (verifier_ == nullptr) {
77       return status;
78     }
79     return verifier_->VerifyCertChain(hostname, port, certs, ocsp_response,
80                                       cert_sct, context, error_details, details,
81                                       out_alert, std::move(callback));
82   }
83 
CreateDefaultContext()84   std::unique_ptr<ProofVerifyContext> CreateDefaultContext() override {
85     return verifier_ != nullptr ? verifier_->CreateDefaultContext() : nullptr;
86   }
87 
common_name() const88   const std::string& common_name() const { return common_name_; }
89 
cert_sct() const90   const std::string& cert_sct() const { return cert_sct_; }
91 
92  private:
ProcessCerts(const std::vector<std::string> & certs,const std::string & cert_sct)93   QuicAsyncStatus ProcessCerts(const std::vector<std::string>& certs,
94                                const std::string& cert_sct) {
95     common_name_.clear();
96     if (certs.empty()) {
97       return QUIC_FAILURE;
98     }
99 
100     // Parse the cert into an X509 structure.
101     const uint8_t* data;
102     data = reinterpret_cast<const uint8_t*>(certs[0].data());
103     bssl::UniquePtr<X509> cert(d2i_X509(nullptr, &data, certs[0].size()));
104     if (!cert.get()) {
105       return QUIC_FAILURE;
106     }
107 
108     // Extract the CN field
109     X509_NAME* subject = X509_get_subject_name(cert.get());
110     const int index = X509_NAME_get_index_by_NID(subject, NID_commonName, -1);
111     if (index < 0) {
112       return QUIC_FAILURE;
113     }
114     ASN1_STRING* name_data =
115         X509_NAME_ENTRY_get_data(X509_NAME_get_entry(subject, index));
116     if (name_data == nullptr) {
117       return QUIC_FAILURE;
118     }
119 
120     // Convert the CN to UTF8, in case the cert represents it in a different
121     // format.
122     unsigned char* buf = nullptr;
123     const int len = ASN1_STRING_to_UTF8(&buf, name_data);
124     if (len <= 0) {
125       return QUIC_FAILURE;
126     }
127     bssl::UniquePtr<unsigned char> deleter(buf);
128 
129     common_name_.assign(reinterpret_cast<const char*>(buf), len);
130     cert_sct_ = cert_sct;
131     return QUIC_SUCCESS;
132   }
133 
134   std::unique_ptr<ProofVerifier> verifier_;
135   std::string common_name_;
136   std::string cert_sct_;
137 };
138 }  // namespace
139 
ProcessPacket(const QuicSocketAddress & self_address,const QuicSocketAddress & peer_address,const QuicReceivedPacket & packet)140 void MockableQuicClientDefaultNetworkHelper::ProcessPacket(
141     const QuicSocketAddress& self_address,
142     const QuicSocketAddress& peer_address, const QuicReceivedPacket& packet) {
143   QuicClientDefaultNetworkHelper::ProcessPacket(self_address, peer_address,
144                                                 packet);
145   if (track_last_incoming_packet_) {
146     last_incoming_packet_ = packet.Clone();
147   }
148 }
149 
CreateUDPSocket(QuicSocketAddress server_address,bool * overflow_supported)150 SocketFd MockableQuicClientDefaultNetworkHelper::CreateUDPSocket(
151     QuicSocketAddress server_address, bool* overflow_supported) {
152   SocketFd fd = QuicClientDefaultNetworkHelper::CreateUDPSocket(
153       server_address, overflow_supported);
154   if (fd < 0) {
155     return fd;
156   }
157 
158   if (socket_fd_configurator_ != nullptr) {
159     socket_fd_configurator_(fd);
160   }
161   return fd;
162 }
163 
164 QuicPacketWriter*
CreateQuicPacketWriter()165 MockableQuicClientDefaultNetworkHelper::CreateQuicPacketWriter() {
166   QuicPacketWriter* writer =
167       QuicClientDefaultNetworkHelper::CreateQuicPacketWriter();
168   if (!test_writer_) {
169     return writer;
170   }
171   test_writer_->set_writer(writer);
172   return test_writer_;
173 }
174 
set_socket_fd_configurator(quiche::MultiUseCallback<void (SocketFd)> socket_fd_configurator)175 void MockableQuicClientDefaultNetworkHelper::set_socket_fd_configurator(
176     quiche::MultiUseCallback<void(SocketFd)> socket_fd_configurator) {
177   socket_fd_configurator_ = std::move(socket_fd_configurator);
178 }
179 
180 const QuicReceivedPacket*
last_incoming_packet()181 MockableQuicClientDefaultNetworkHelper::last_incoming_packet() {
182   return last_incoming_packet_.get();
183 }
184 
set_track_last_incoming_packet(bool track)185 void MockableQuicClientDefaultNetworkHelper::set_track_last_incoming_packet(
186     bool track) {
187   track_last_incoming_packet_ = track;
188 }
189 
UseWriter(QuicPacketWriterWrapper * writer)190 void MockableQuicClientDefaultNetworkHelper::UseWriter(
191     QuicPacketWriterWrapper* writer) {
192   QUICHE_CHECK(test_writer_ == nullptr);
193   test_writer_ = writer;
194 }
195 
set_peer_address(const QuicSocketAddress & address)196 void MockableQuicClientDefaultNetworkHelper::set_peer_address(
197     const QuicSocketAddress& address) {
198   QUICHE_CHECK(test_writer_ != nullptr);
199   test_writer_->set_peer_address(address);
200 }
201 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const ParsedQuicVersionVector & supported_versions,QuicEventLoop * event_loop)202 MockableQuicClient::MockableQuicClient(
203     QuicSocketAddress server_address, const QuicServerId& server_id,
204     const ParsedQuicVersionVector& supported_versions,
205     QuicEventLoop* event_loop)
206     : MockableQuicClient(server_address, server_id, QuicConfig(),
207                          supported_versions, event_loop) {}
208 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEventLoop * event_loop)209 MockableQuicClient::MockableQuicClient(
210     QuicSocketAddress server_address, const QuicServerId& server_id,
211     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
212     QuicEventLoop* event_loop)
213     : MockableQuicClient(server_address, server_id, config, supported_versions,
214                          event_loop, nullptr) {}
215 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEventLoop * event_loop,std::unique_ptr<ProofVerifier> proof_verifier)216 MockableQuicClient::MockableQuicClient(
217     QuicSocketAddress server_address, const QuicServerId& server_id,
218     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
219     QuicEventLoop* event_loop, std::unique_ptr<ProofVerifier> proof_verifier)
220     : MockableQuicClient(server_address, server_id, config, supported_versions,
221                          event_loop, std::move(proof_verifier), nullptr) {}
222 
MockableQuicClient(QuicSocketAddress server_address,const QuicServerId & server_id,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,QuicEventLoop * event_loop,std::unique_ptr<ProofVerifier> proof_verifier,std::unique_ptr<SessionCache> session_cache)223 MockableQuicClient::MockableQuicClient(
224     QuicSocketAddress server_address, const QuicServerId& server_id,
225     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
226     QuicEventLoop* event_loop, std::unique_ptr<ProofVerifier> proof_verifier,
227     std::unique_ptr<SessionCache> session_cache)
228     : QuicDefaultClient(
229           server_address, server_id, supported_versions, config, event_loop,
230           std::make_unique<MockableQuicClientDefaultNetworkHelper>(event_loop,
231                                                                    this),
232           std::make_unique<RecordingProofVerifier>(std::move(proof_verifier)),
233           std::move(session_cache)),
234       override_client_connection_id_(EmptyQuicConnectionId()),
235       client_connection_id_overridden_(false) {}
236 
~MockableQuicClient()237 MockableQuicClient::~MockableQuicClient() {
238   if (connected()) {
239     Disconnect();
240   }
241 }
242 
243 MockableQuicClientDefaultNetworkHelper*
mockable_network_helper()244 MockableQuicClient::mockable_network_helper() {
245   return static_cast<MockableQuicClientDefaultNetworkHelper*>(
246       default_network_helper());
247 }
248 
249 const MockableQuicClientDefaultNetworkHelper*
mockable_network_helper() const250 MockableQuicClient::mockable_network_helper() const {
251   return static_cast<const MockableQuicClientDefaultNetworkHelper*>(
252       default_network_helper());
253 }
254 
GetClientConnectionId()255 QuicConnectionId MockableQuicClient::GetClientConnectionId() {
256   if (client_connection_id_overridden_) {
257     return override_client_connection_id_;
258   }
259   if (override_client_connection_id_length_ >= 0) {
260     return QuicUtils::CreateRandomConnectionId(
261         override_client_connection_id_length_);
262   }
263   return QuicDefaultClient::GetClientConnectionId();
264 }
265 
UseClientConnectionId(QuicConnectionId client_connection_id)266 void MockableQuicClient::UseClientConnectionId(
267     QuicConnectionId client_connection_id) {
268   client_connection_id_overridden_ = true;
269   override_client_connection_id_ = client_connection_id;
270 }
271 
UseClientConnectionIdLength(int client_connection_id_length)272 void MockableQuicClient::UseClientConnectionIdLength(
273     int client_connection_id_length) {
274   override_client_connection_id_length_ = client_connection_id_length;
275 }
276 
UseWriter(QuicPacketWriterWrapper * writer)277 void MockableQuicClient::UseWriter(QuicPacketWriterWrapper* writer) {
278   mockable_network_helper()->UseWriter(writer);
279 }
280 
set_peer_address(const QuicSocketAddress & address)281 void MockableQuicClient::set_peer_address(const QuicSocketAddress& address) {
282   mockable_network_helper()->set_peer_address(address);
283   if (client_session() != nullptr) {
284     client_session()->connection()->AddKnownServerAddress(address);
285   }
286 }
287 
last_incoming_packet()288 const QuicReceivedPacket* MockableQuicClient::last_incoming_packet() {
289   return mockable_network_helper()->last_incoming_packet();
290 }
291 
set_track_last_incoming_packet(bool track)292 void MockableQuicClient::set_track_last_incoming_packet(bool track) {
293   mockable_network_helper()->set_track_last_incoming_packet(track);
294 }
295 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const ParsedQuicVersionVector & supported_versions)296 QuicTestClient::QuicTestClient(
297     QuicSocketAddress server_address, const std::string& server_hostname,
298     const ParsedQuicVersionVector& supported_versions)
299     : QuicTestClient(server_address, server_hostname, QuicConfig(),
300                      supported_versions) {}
301 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions)302 QuicTestClient::QuicTestClient(
303     QuicSocketAddress server_address, const std::string& server_hostname,
304     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions)
305     : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())),
306       client_(std::make_unique<MockableQuicClient>(
307           server_address,
308           QuicServerId(server_hostname, server_address.port(), false), config,
309           supported_versions, event_loop_.get())) {
310   Initialize();
311 }
312 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,std::unique_ptr<ProofVerifier> proof_verifier)313 QuicTestClient::QuicTestClient(
314     QuicSocketAddress server_address, const std::string& server_hostname,
315     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
316     std::unique_ptr<ProofVerifier> proof_verifier)
317     : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())),
318       client_(std::make_unique<MockableQuicClient>(
319           server_address,
320           QuicServerId(server_hostname, server_address.port(), false), config,
321           supported_versions, event_loop_.get(), std::move(proof_verifier))) {
322   Initialize();
323 }
324 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,std::unique_ptr<ProofVerifier> proof_verifier,std::unique_ptr<SessionCache> session_cache)325 QuicTestClient::QuicTestClient(
326     QuicSocketAddress server_address, const std::string& server_hostname,
327     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
328     std::unique_ptr<ProofVerifier> proof_verifier,
329     std::unique_ptr<SessionCache> session_cache)
330     : event_loop_(GetDefaultEventLoop()->Create(QuicDefaultClock::Get())),
331       client_(std::make_unique<MockableQuicClient>(
332           server_address,
333           QuicServerId(server_hostname, server_address.port(), false), config,
334           supported_versions, event_loop_.get(), std::move(proof_verifier),
335           std::move(session_cache))) {
336   Initialize();
337 }
338 
QuicTestClient(QuicSocketAddress server_address,const std::string & server_hostname,const QuicConfig & config,const ParsedQuicVersionVector & supported_versions,std::unique_ptr<ProofVerifier> proof_verifier,std::unique_ptr<SessionCache> session_cache,std::unique_ptr<QuicEventLoop> event_loop)339 QuicTestClient::QuicTestClient(
340     QuicSocketAddress server_address, const std::string& server_hostname,
341     const QuicConfig& config, const ParsedQuicVersionVector& supported_versions,
342     std::unique_ptr<ProofVerifier> proof_verifier,
343     std::unique_ptr<SessionCache> session_cache,
344     std::unique_ptr<QuicEventLoop> event_loop)
345     : event_loop_(std::move(event_loop)),
346       client_(std::make_unique<MockableQuicClient>(
347           server_address,
348           QuicServerId(server_hostname, server_address.port(), false), config,
349           supported_versions, event_loop_.get(), std::move(proof_verifier),
350           std::move(session_cache))) {
351   Initialize();
352 }
353 
354 QuicTestClient::QuicTestClient() = default;
355 
~QuicTestClient()356 QuicTestClient::~QuicTestClient() {
357   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
358     stream.second->set_visitor(nullptr);
359   }
360 }
361 
Initialize()362 void QuicTestClient::Initialize() {
363   priority_ = 3;
364   connect_attempted_ = false;
365   auto_reconnect_ = false;
366   buffer_body_ = true;
367   num_requests_ = 0;
368   num_responses_ = 0;
369   ClearPerConnectionState();
370   // As chrome will generally do this, we want it to be the default when it's
371   // not overridden.
372   if (!client_->config()->HasSetBytesForConnectionIdToSend()) {
373     client_->config()->SetBytesForConnectionIdToSend(0);
374   }
375 }
376 
SetUserAgentID(const std::string & user_agent_id)377 void QuicTestClient::SetUserAgentID(const std::string& user_agent_id) {
378   client_->SetUserAgentID(user_agent_id);
379 }
380 
SendRequest(const std::string & uri)381 int64_t QuicTestClient::SendRequest(const std::string& uri) {
382   spdy::Http2HeaderBlock headers;
383   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
384     return 0;
385   }
386   return SendMessage(headers, "");
387 }
388 
SendRequestAndRstTogether(const std::string & uri)389 int64_t QuicTestClient::SendRequestAndRstTogether(const std::string& uri) {
390   spdy::Http2HeaderBlock headers;
391   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
392     return 0;
393   }
394 
395   QuicSpdyClientSession* session = client()->client_session();
396   QuicConnection::ScopedPacketFlusher flusher(session->connection());
397   int64_t ret = SendMessage(headers, "", /*fin=*/true, /*flush=*/false);
398 
399   QuicStreamId stream_id = GetNthClientInitiatedBidirectionalStreamId(
400       session->transport_version(), 0);
401   session->ResetStream(stream_id, QUIC_STREAM_CANCELLED);
402   return ret;
403 }
404 
SendRequestsAndWaitForResponses(const std::vector<std::string> & url_list)405 void QuicTestClient::SendRequestsAndWaitForResponses(
406     const std::vector<std::string>& url_list) {
407   for (const std::string& url : url_list) {
408     SendRequest(url);
409   }
410   while (client()->WaitForEvents()) {
411   }
412 }
413 
GetOrCreateStreamAndSendRequest(const spdy::Http2HeaderBlock * headers,absl::string_view body,bool fin,quiche::QuicheReferenceCountedPointer<QuicAckListenerInterface> ack_listener)414 int64_t QuicTestClient::GetOrCreateStreamAndSendRequest(
415     const spdy::Http2HeaderBlock* headers, absl::string_view body, bool fin,
416     quiche::QuicheReferenceCountedPointer<QuicAckListenerInterface>
417         ack_listener) {
418   // Maybe it's better just to overload this.  it's just that we need
419   // for the GetOrCreateStream function to call something else...which
420   // is icky and complicated, but maybe not worse than this.
421   QuicSpdyClientStream* stream = GetOrCreateStream();
422   if (stream == nullptr) {
423     return 0;
424   }
425   QuicSpdyStreamPeer::set_ack_listener(stream, ack_listener);
426 
427   int64_t ret = 0;
428   if (headers != nullptr) {
429     spdy::Http2HeaderBlock spdy_headers(headers->Clone());
430     if (spdy_headers[":authority"].as_string().empty()) {
431       spdy_headers[":authority"] = client_->server_id().host();
432     }
433     ret = stream->SendRequest(std::move(spdy_headers), body, fin);
434     ++num_requests_;
435   } else {
436     stream->WriteOrBufferBody(std::string(body), fin);
437     ret = body.length();
438   }
439   return ret;
440 }
441 
SendMessage(const spdy::Http2HeaderBlock & headers,absl::string_view body)442 int64_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers,
443                                     absl::string_view body) {
444   return SendMessage(headers, body, /*fin=*/true);
445 }
446 
SendMessage(const spdy::Http2HeaderBlock & headers,absl::string_view body,bool fin)447 int64_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers,
448                                     absl::string_view body, bool fin) {
449   return SendMessage(headers, body, fin, /*flush=*/true);
450 }
451 
SendMessage(const spdy::Http2HeaderBlock & headers,absl::string_view body,bool fin,bool flush)452 int64_t QuicTestClient::SendMessage(const spdy::Http2HeaderBlock& headers,
453                                     absl::string_view body, bool fin,
454                                     bool flush) {
455   // Always force creation of a stream for SendMessage.
456   latest_created_stream_ = nullptr;
457 
458   int64_t ret = GetOrCreateStreamAndSendRequest(&headers, body, fin, nullptr);
459 
460   if (flush) {
461     WaitForWriteToFlush();
462   }
463   return ret;
464 }
465 
SendData(const std::string & data,bool last_data)466 int64_t QuicTestClient::SendData(const std::string& data, bool last_data) {
467   return SendData(data, last_data, nullptr);
468 }
469 
SendData(const std::string & data,bool last_data,quiche::QuicheReferenceCountedPointer<QuicAckListenerInterface> ack_listener)470 int64_t QuicTestClient::SendData(
471     const std::string& data, bool last_data,
472     quiche::QuicheReferenceCountedPointer<QuicAckListenerInterface>
473         ack_listener) {
474   return GetOrCreateStreamAndSendRequest(nullptr, absl::string_view(data),
475                                          last_data, std::move(ack_listener));
476 }
477 
response_complete() const478 bool QuicTestClient::response_complete() const { return response_complete_; }
479 
response_body_size() const480 int64_t QuicTestClient::response_body_size() const {
481   return response_body_size_;
482 }
483 
buffer_body() const484 bool QuicTestClient::buffer_body() const { return buffer_body_; }
485 
set_buffer_body(bool buffer_body)486 void QuicTestClient::set_buffer_body(bool buffer_body) {
487   buffer_body_ = buffer_body;
488 }
489 
response_body() const490 const std::string& QuicTestClient::response_body() const { return response_; }
491 
SendCustomSynchronousRequest(const spdy::Http2HeaderBlock & headers,const std::string & body)492 std::string QuicTestClient::SendCustomSynchronousRequest(
493     const spdy::Http2HeaderBlock& headers, const std::string& body) {
494   // Clear connection state here and only track this synchronous request.
495   ClearPerConnectionState();
496   if (SendMessage(headers, body) == 0) {
497     QUIC_DLOG(ERROR) << "Failed the request for: " << headers.DebugString();
498     // Set the response_ explicitly.  Otherwise response_ will contain the
499     // response from the previously successful request.
500     response_ = "";
501   } else {
502     WaitForResponse();
503   }
504   return response_;
505 }
506 
SendSynchronousRequest(const std::string & uri)507 std::string QuicTestClient::SendSynchronousRequest(const std::string& uri) {
508   spdy::Http2HeaderBlock headers;
509   if (!PopulateHeaderBlockFromUrl(uri, &headers)) {
510     return "";
511   }
512   return SendCustomSynchronousRequest(headers, "");
513 }
514 
SendConnectivityProbing()515 void QuicTestClient::SendConnectivityProbing() {
516   QuicConnection* connection = client()->client_session()->connection();
517   connection->SendConnectivityProbingPacket(connection->writer(),
518                                             connection->peer_address());
519 }
520 
SetLatestCreatedStream(QuicSpdyClientStream * stream)521 void QuicTestClient::SetLatestCreatedStream(QuicSpdyClientStream* stream) {
522   latest_created_stream_ = stream;
523   if (latest_created_stream_ != nullptr) {
524     open_streams_[stream->id()] = stream;
525     stream->set_visitor(this);
526   }
527 }
528 
GetOrCreateStream()529 QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() {
530   if (!connect_attempted_ || auto_reconnect_) {
531     if (!connected()) {
532       Connect();
533     }
534     if (!connected()) {
535       return nullptr;
536     }
537   }
538   if (open_streams_.empty()) {
539     ClearPerConnectionState();
540   }
541   if (!latest_created_stream_) {
542     SetLatestCreatedStream(client_->CreateClientStream());
543     if (latest_created_stream_) {
544       latest_created_stream_->SetPriority(QuicStreamPriority(
545           HttpStreamPriority{priority_, /* incremental = */ false}));
546     }
547   }
548 
549   return latest_created_stream_;
550 }
551 
connection_error() const552 QuicErrorCode QuicTestClient::connection_error() const {
553   return client()->connection_error();
554 }
555 
cert_common_name() const556 const std::string& QuicTestClient::cert_common_name() const {
557   return reinterpret_cast<RecordingProofVerifier*>(client_->proof_verifier())
558       ->common_name();
559 }
560 
cert_sct() const561 const std::string& QuicTestClient::cert_sct() const {
562   return reinterpret_cast<RecordingProofVerifier*>(client_->proof_verifier())
563       ->cert_sct();
564 }
565 
GetServerConfig() const566 const QuicTagValueMap& QuicTestClient::GetServerConfig() const {
567   QuicCryptoClientConfig* config = client_->crypto_config();
568   const QuicCryptoClientConfig::CachedState* state =
569       config->LookupOrCreate(client_->server_id());
570   const CryptoHandshakeMessage* handshake_msg = state->GetServerConfig();
571   return handshake_msg->tag_value_map();
572 }
573 
connected() const574 bool QuicTestClient::connected() const { return client_->connected(); }
575 
Connect()576 void QuicTestClient::Connect() {
577   if (connected()) {
578     QUIC_BUG(quic_bug_10133_1) << "Cannot connect already-connected client";
579     return;
580   }
581   if (!connect_attempted_) {
582     client_->Initialize();
583   }
584 
585   // If we've been asked to override SNI, set it now
586   if (override_sni_set_) {
587     client_->set_server_id(
588         QuicServerId(override_sni_, address().port(), false));
589   }
590 
591   client_->Connect();
592   connect_attempted_ = true;
593 }
594 
ResetConnection()595 void QuicTestClient::ResetConnection() {
596   Disconnect();
597   Connect();
598 }
599 
Disconnect()600 void QuicTestClient::Disconnect() {
601   ClearPerConnectionState();
602   if (client_->initialized()) {
603     client_->Disconnect();
604   }
605   connect_attempted_ = false;
606 }
607 
local_address() const608 QuicSocketAddress QuicTestClient::local_address() const {
609   return client_->network_helper()->GetLatestClientAddress();
610 }
611 
ClearPerRequestState()612 void QuicTestClient::ClearPerRequestState() {
613   stream_error_ = QUIC_STREAM_NO_ERROR;
614   response_ = "";
615   response_complete_ = false;
616   response_headers_complete_ = false;
617   response_headers_.clear();
618   response_trailers_.clear();
619   bytes_read_ = 0;
620   bytes_written_ = 0;
621   response_body_size_ = 0;
622 }
623 
HaveActiveStream()624 bool QuicTestClient::HaveActiveStream() { return !open_streams_.empty(); }
625 
WaitUntil(int timeout_ms,std::optional<quiche::UnretainedCallback<bool ()>> trigger)626 bool QuicTestClient::WaitUntil(
627     int timeout_ms, std::optional<quiche::UnretainedCallback<bool()>> trigger) {
628   QuicTime::Delta timeout = QuicTime::Delta::FromMilliseconds(timeout_ms);
629   const QuicClock* clock = client()->session()->connection()->clock();
630   QuicTime end_waiting_time = clock->Now() + timeout;
631   while (connected() && !(trigger.has_value() && (*trigger)()) &&
632          (timeout_ms < 0 || clock->Now() < end_waiting_time)) {
633     event_loop_->RunEventLoopOnce(timeout);
634     client_->WaitForEventsPostprocessing();
635   }
636   ReadNextResponse();
637   if (trigger.has_value() && !(*trigger)()) {
638     QUIC_VLOG(1) << "Client WaitUntil returning with trigger returning false.";
639     return false;
640   }
641   return true;
642 }
643 
Send(absl::string_view data)644 int64_t QuicTestClient::Send(absl::string_view data) {
645   return SendData(std::string(data), false);
646 }
647 
response_headers_complete() const648 bool QuicTestClient::response_headers_complete() const {
649   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
650     if (stream.second->headers_decompressed()) {
651       return true;
652     }
653   }
654   return response_headers_complete_;
655 }
656 
response_headers() const657 const spdy::Http2HeaderBlock* QuicTestClient::response_headers() const {
658   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
659     if (stream.second->headers_decompressed()) {
660       response_headers_ = stream.second->response_headers().Clone();
661       break;
662     }
663   }
664   return &response_headers_;
665 }
666 
response_trailers() const667 const spdy::Http2HeaderBlock& QuicTestClient::response_trailers() const {
668   return response_trailers_;
669 }
670 
response_size() const671 int64_t QuicTestClient::response_size() const { return bytes_read(); }
672 
bytes_read() const673 size_t QuicTestClient::bytes_read() const {
674   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
675     size_t bytes_read = stream.second->total_body_bytes_read() +
676                         stream.second->header_bytes_read();
677     if (bytes_read > 0) {
678       return bytes_read;
679     }
680   }
681   return bytes_read_;
682 }
683 
bytes_written() const684 size_t QuicTestClient::bytes_written() const {
685   for (std::pair<QuicStreamId, QuicSpdyClientStream*> stream : open_streams_) {
686     size_t bytes_written = stream.second->stream_bytes_written() +
687                            stream.second->header_bytes_written();
688     if (bytes_written > 0) {
689       return bytes_written;
690     }
691   }
692   return bytes_written_;
693 }
694 
partial_response_body() const695 absl::string_view QuicTestClient::partial_response_body() const {
696   return latest_created_stream_ == nullptr ? ""
697                                            : latest_created_stream_->data();
698 }
699 
OnClose(QuicSpdyStream * stream)700 void QuicTestClient::OnClose(QuicSpdyStream* stream) {
701   if (stream == nullptr) {
702     return;
703   }
704   // Always close the stream, regardless of whether it was the last stream
705   // written.
706   client()->OnClose(stream);
707   ++num_responses_;
708   if (open_streams_.find(stream->id()) == open_streams_.end()) {
709     return;
710   }
711   if (latest_created_stream_ == stream) {
712     latest_created_stream_ = nullptr;
713   }
714   QuicSpdyClientStream* client_stream =
715       static_cast<QuicSpdyClientStream*>(stream);
716   QuicStreamId id = client_stream->id();
717   closed_stream_states_.insert(std::make_pair(
718       id,
719       PerStreamState(
720           // Set response_complete to true iff stream is closed while connected.
721           client_stream->stream_error(), connected(),
722           client_stream->headers_decompressed(),
723           client_stream->response_headers(),
724           (buffer_body() ? std::string(client_stream->data()) : ""),
725           client_stream->received_trailers(),
726           // Use NumBytesConsumed to avoid counting retransmitted stream frames.
727           client_stream->total_body_bytes_read() +
728               client_stream->header_bytes_read(),
729           client_stream->stream_bytes_written() +
730               client_stream->header_bytes_written(),
731           client_stream->data().size())));
732   open_streams_.erase(id);
733 }
734 
UseWriter(QuicPacketWriterWrapper * writer)735 void QuicTestClient::UseWriter(QuicPacketWriterWrapper* writer) {
736   client_->UseWriter(writer);
737 }
738 
UseConnectionId(QuicConnectionId server_connection_id)739 void QuicTestClient::UseConnectionId(QuicConnectionId server_connection_id) {
740   QUICHE_DCHECK(!connected());
741   client_->set_server_connection_id_override(server_connection_id);
742 }
743 
UseConnectionIdLength(uint8_t server_connection_id_length)744 void QuicTestClient::UseConnectionIdLength(
745     uint8_t server_connection_id_length) {
746   QUICHE_DCHECK(!connected());
747   client_->set_server_connection_id_length(server_connection_id_length);
748 }
749 
UseClientConnectionId(QuicConnectionId client_connection_id)750 void QuicTestClient::UseClientConnectionId(
751     QuicConnectionId client_connection_id) {
752   QUICHE_DCHECK(!connected());
753   client_->UseClientConnectionId(client_connection_id);
754 }
755 
UseClientConnectionIdLength(uint8_t client_connection_id_length)756 void QuicTestClient::UseClientConnectionIdLength(
757     uint8_t client_connection_id_length) {
758   QUICHE_DCHECK(!connected());
759   client_->UseClientConnectionIdLength(client_connection_id_length);
760 }
761 
MigrateSocket(const QuicIpAddress & new_host)762 bool QuicTestClient::MigrateSocket(const QuicIpAddress& new_host) {
763   return client_->MigrateSocket(new_host);
764 }
765 
MigrateSocketWithSpecifiedPort(const QuicIpAddress & new_host,int port)766 bool QuicTestClient::MigrateSocketWithSpecifiedPort(
767     const QuicIpAddress& new_host, int port) {
768   client_->set_local_port(port);
769   return client_->MigrateSocket(new_host);
770 }
771 
bind_to_address() const772 QuicIpAddress QuicTestClient::bind_to_address() const {
773   return client_->bind_to_address();
774 }
775 
set_bind_to_address(QuicIpAddress address)776 void QuicTestClient::set_bind_to_address(QuicIpAddress address) {
777   client_->set_bind_to_address(address);
778 }
779 
address() const780 const QuicSocketAddress& QuicTestClient::address() const {
781   return client_->server_address();
782 }
783 
WaitForWriteToFlush()784 void QuicTestClient::WaitForWriteToFlush() {
785   while (connected() && client()->session()->HasDataToWrite()) {
786     client_->WaitForEvents();
787   }
788 }
789 
PerStreamState(const PerStreamState & other)790 QuicTestClient::PerStreamState::PerStreamState(const PerStreamState& other)
791     : stream_error(other.stream_error),
792       response_complete(other.response_complete),
793       response_headers_complete(other.response_headers_complete),
794       response_headers(other.response_headers.Clone()),
795       response(other.response),
796       response_trailers(other.response_trailers.Clone()),
797       bytes_read(other.bytes_read),
798       bytes_written(other.bytes_written),
799       response_body_size(other.response_body_size) {}
800 
PerStreamState(QuicRstStreamErrorCode stream_error,bool response_complete,bool response_headers_complete,const spdy::Http2HeaderBlock & response_headers,const std::string & response,const spdy::Http2HeaderBlock & response_trailers,uint64_t bytes_read,uint64_t bytes_written,int64_t response_body_size)801 QuicTestClient::PerStreamState::PerStreamState(
802     QuicRstStreamErrorCode stream_error, bool response_complete,
803     bool response_headers_complete,
804     const spdy::Http2HeaderBlock& response_headers, const std::string& response,
805     const spdy::Http2HeaderBlock& response_trailers, uint64_t bytes_read,
806     uint64_t bytes_written, int64_t response_body_size)
807     : stream_error(stream_error),
808       response_complete(response_complete),
809       response_headers_complete(response_headers_complete),
810       response_headers(response_headers.Clone()),
811       response(response),
812       response_trailers(response_trailers.Clone()),
813       bytes_read(bytes_read),
814       bytes_written(bytes_written),
815       response_body_size(response_body_size) {}
816 
817 QuicTestClient::PerStreamState::~PerStreamState() = default;
818 
PopulateHeaderBlockFromUrl(const std::string & uri,spdy::Http2HeaderBlock * headers)819 bool QuicTestClient::PopulateHeaderBlockFromUrl(
820     const std::string& uri, spdy::Http2HeaderBlock* headers) {
821   std::string url;
822   if (absl::StartsWith(uri, "https://") || absl::StartsWith(uri, "http://")) {
823     url = uri;
824   } else if (uri[0] == '/') {
825     url = "https://" + client_->server_id().host() + uri;
826   } else {
827     url = "https://" + uri;
828   }
829   return SpdyUtils::PopulateHeaderBlockFromUrl(url, headers);
830 }
831 
ReadNextResponse()832 void QuicTestClient::ReadNextResponse() {
833   if (closed_stream_states_.empty()) {
834     return;
835   }
836 
837   PerStreamState state(closed_stream_states_.front().second);
838 
839   stream_error_ = state.stream_error;
840   response_ = state.response;
841   response_complete_ = state.response_complete;
842   response_headers_complete_ = state.response_headers_complete;
843   response_headers_ = state.response_headers.Clone();
844   response_trailers_ = state.response_trailers.Clone();
845   bytes_read_ = state.bytes_read;
846   bytes_written_ = state.bytes_written;
847   response_body_size_ = state.response_body_size;
848 
849   closed_stream_states_.pop_front();
850 }
851 
ClearPerConnectionState()852 void QuicTestClient::ClearPerConnectionState() {
853   ClearPerRequestState();
854   open_streams_.clear();
855   closed_stream_states_.clear();
856   latest_created_stream_ = nullptr;
857 }
858 
WaitForDelayedAcks()859 void QuicTestClient::WaitForDelayedAcks() {
860   // kWaitDuration is a period of time that is long enough for all delayed
861   // acks to be sent and received on the other end.
862   const QuicTime::Delta kWaitDuration =
863       4 * QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs);
864 
865   const QuicClock* clock = client()->client_session()->connection()->clock();
866 
867   QuicTime wait_until = clock->ApproximateNow() + kWaitDuration;
868   while (connected() && clock->ApproximateNow() < wait_until) {
869     // This waits for up to 50 ms.
870     client()->WaitForEvents();
871   }
872 }
873 
874 }  // namespace test
875 }  // namespace quic
876