xref: /aosp_15_r20/external/cronet/net/websockets/websocket_handshake_stream_create_helper_test.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6 
7 #include <optional>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
#include "base/auto_reset.h"
#include "base/containers/span.h"
#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/notreached.h"
#include "base/strings/string_piece.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/default_tick_clock.h"
19 #include "base/time/time.h"
20 #include "net/base/auth.h"
21 #include "net/base/completion_once_callback.h"
22 #include "net/base/connection_endpoint_metadata.h"
23 #include "net/base/host_port_pair.h"
24 #include "net/base/ip_address.h"
25 #include "net/base/ip_endpoint.h"
26 #include "net/base/load_flags.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_anonymization_key.h"
29 #include "net/base/network_handle.h"
30 #include "net/base/privacy_mode.h"
31 #include "net/base/proxy_chain.h"
32 #include "net/base/proxy_server.h"
33 #include "net/base/request_priority.h"
34 #include "net/base/session_usage.h"
35 #include "net/base/test_completion_callback.h"
36 #include "net/cert/cert_verify_result.h"
37 #include "net/dns/public/host_resolver_results.h"
38 #include "net/dns/public/secure_dns_policy.h"
39 #include "net/http/http_request_info.h"
40 #include "net/http/http_response_headers.h"
41 #include "net/http/http_response_info.h"
42 #include "net/http/transport_security_state.h"
43 #include "net/log/net_log.h"
44 #include "net/log/net_log_with_source.h"
45 #include "net/quic/address_utils.h"
46 #include "net/quic/crypto/proof_verifier_chromium.h"
47 #include "net/quic/mock_crypto_client_stream_factory.h"
48 #include "net/quic/mock_quic_data.h"
49 #include "net/quic/quic_chromium_alarm_factory.h"
50 #include "net/quic/quic_chromium_connection_helper.h"
51 #include "net/quic/quic_chromium_packet_reader.h"
52 #include "net/quic/quic_chromium_packet_writer.h"
53 #include "net/quic/quic_context.h"
54 #include "net/quic/quic_http_utils.h"
55 #include "net/quic/quic_server_info.h"
56 #include "net/quic/quic_session_key.h"
57 #include "net/quic/quic_test_packet_maker.h"
58 #include "net/quic/test_quic_crypto_client_config_handle.h"
59 #include "net/quic/test_task_runner.h"
60 #include "net/socket/client_socket_handle.h"
61 #include "net/socket/client_socket_pool.h"
62 #include "net/socket/connect_job.h"
63 #include "net/socket/socket_tag.h"
64 #include "net/socket/socket_test_util.h"
65 #include "net/socket/websocket_endpoint_lock_manager.h"
66 #include "net/spdy/spdy_session_key.h"
67 #include "net/spdy/spdy_test_util_common.h"
68 #include "net/ssl/ssl_config_service_defaults.h"
69 #include "net/ssl/ssl_info.h"
70 #include "net/test/cert_test_util.h"
71 #include "net/test/gtest_util.h"
72 #include "net/test/test_data_directory.h"
73 #include "net/test/test_with_task_environment.h"
74 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
75 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
76 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
77 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
81 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
82 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
83 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
84 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
85 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
86 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
88 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
89 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
90 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
91 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
92 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
93 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
94 #include "net/traffic_annotation/network_traffic_annotation.h"
95 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
96 #include "net/websockets/websocket_basic_handshake_stream.h"
97 #include "net/websockets/websocket_event_interface.h"
98 #include "net/websockets/websocket_stream.h"
99 #include "net/websockets/websocket_test_util.h"
100 #include "testing/gmock/include/gmock/gmock.h"
101 #include "testing/gtest/include/gtest/gtest.h"
102 #include "url/gurl.h"
103 #include "url/origin.h"
104 #include "url/scheme_host_port.h"
105 #include "url/url_constants.h"
106 
107 namespace net {
108 class HttpNetworkSession;
109 class URLRequest;
110 class WebSocketHttp2HandshakeStream;
111 class WebSocketHttp3HandshakeStream;
112 class X509Certificate;
113 struct WebSocketHandshakeRequestInfo;
114 struct WebSocketHandshakeResponseInfo;
115 }  // namespace net
116 
117 using ::net::test::IsError;
118 using ::net::test::IsOk;
119 using ::testing::_;
120 using ::testing::StrictMock;
121 using ::testing::TestWithParam;
122 using ::testing::Values;
123 
124 namespace net {
125 namespace {
126 
// Selects which kind of WebSocket handshake stream a parameterized test
// creates and drives.
enum HandshakeStreamType {
  BASIC_HANDSHAKE_STREAM = 0,  // WebSocketBasicHandshakeStream (HTTP/1.1).
  HTTP2_HANDSHAKE_STREAM = 1,  // HTTP/2 handshake stream.
  HTTP3_HANDSHAKE_STREAM = 2,  // HTTP/3 (QUIC) handshake stream.
};
132 
133 // This class encapsulates the details of creating a mock ClientSocketHandle.
134 class MockClientSocketHandleFactory {
135  public:
MockClientSocketHandleFactory()136   MockClientSocketHandleFactory()
137       : common_connect_job_params_(
138             socket_factory_maker_.factory(),
139             /*host_resolver=*/nullptr,
140             /*http_auth_cache=*/nullptr,
141             /*http_auth_handler_factory=*/nullptr,
142             /*spdy_session_pool=*/nullptr,
143             /*quic_supported_versions=*/nullptr,
144             /*quic_session_pool=*/nullptr,
145             /*proxy_delegate=*/nullptr,
146             /*http_user_agent_settings=*/nullptr,
147             /*ssl_client_context=*/nullptr,
148             /*socket_performance_watcher_factory=*/nullptr,
149             /*network_quality_estimator=*/nullptr,
150             /*net_log=*/nullptr,
151             /*websocket_endpoint_lock_manager=*/nullptr,
152             /*http_server_properties=*/nullptr,
153             /*alpn_protos=*/nullptr,
154             /*application_settings=*/nullptr,
155             /*ignore_certificate_errors=*/nullptr,
156             /*early_data_enabled=*/nullptr),
157         pool_(1, 1, &common_connect_job_params_) {}
158 
159   MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
160   MockClientSocketHandleFactory& operator=(
161       const MockClientSocketHandleFactory&) = delete;
162 
163   // The created socket expects |expect_written| to be written to the socket,
164   // and will respond with |return_to_read|. The test will fail if the expected
165   // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)166   std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
167       const std::string& expect_written,
168       const std::string& return_to_read) {
169     socket_factory_maker_.SetExpectations(expect_written, return_to_read);
170     auto socket_handle = std::make_unique<ClientSocketHandle>();
171     socket_handle->Init(
172         ClientSocketPool::GroupId(
173             url::SchemeHostPort(url::kHttpScheme, "a", 80),
174             PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
175             SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false),
176         scoped_refptr<ClientSocketPool::SocketParams>(),
177         std::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
178         ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
179         ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
180     return socket_handle;
181   }
182 
183  private:
184   WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
185   const CommonConnectJobParams common_connect_job_params_;
186   MockTransportClientSocketPool pool_;
187 };
188 
189 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
190  public:
191   ~TestConnectDelegate() override = default;
192 
OnCreateRequest(URLRequest * request)193   void OnCreateRequest(URLRequest* request) override {}
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)194   void OnURLRequestConnected(URLRequest* request,
195                              const TransportInfo& info) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)196   void OnSuccess(
197       std::unique_ptr<WebSocketStream> stream,
198       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,std::optional<int> response_code)199   void OnFailure(const std::string& failure_message,
200                  int net_error,
201                  std::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)202   void OnStartOpeningHandshake(
203       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)204   void OnSSLCertificateError(
205       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
206           ssl_error_callbacks,
207       int net_error,
208       const SSLInfo& ssl_info,
209       bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)210   int OnAuthRequired(const AuthChallengeInfo& auth_info,
211                      scoped_refptr<HttpResponseHeaders> response_headers,
212                      const IPEndPoint& host_port_pair,
213                      base::OnceCallback<void(const AuthCredentials*)> callback,
214                      std::optional<AuthCredentials>* credentials) override {
215     *credentials = std::nullopt;
216     return OK;
217   }
218 };
219 
220 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
221  public:
222   ~MockWebSocketStreamRequestAPI() override = default;
223 
224   MOCK_METHOD1(OnBasicHandshakeStreamCreated,
225                void(WebSocketBasicHandshakeStream* handshake_stream));
226   MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
227                void(WebSocketHttp2HandshakeStream* handshake_stream));
228   MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
229                void(WebSocketHttp3HandshakeStream* handshake_stream));
230   MOCK_METHOD3(OnFailure,
231                void(const std::string& message,
232                     int net_error,
233                     std::optional<int> response_code));
234 };
235 
236 class WebSocketHandshakeStreamCreateHelperTest
237     : public TestWithParam<HandshakeStreamType>,
238       public WithTaskEnvironment {
239  protected:
WebSocketHandshakeStreamCreateHelperTest()240   WebSocketHandshakeStreamCreateHelperTest()
241       : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
242                       quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
243         mock_quic_data_(quic_version_) {}
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)244   std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
245       const std::vector<std::string>& sub_protocols,
246       const WebSocketExtraHeaders& extra_request_headers,
247       const WebSocketExtraHeaders& extra_response_headers) {
248     constexpr char kPath[] = "/";
249     constexpr char kOrigin[] = "http://origin.example.org";
250     const GURL url("wss://www.example.org/");
251     NetLogWithSource net_log;
252 
253     WebSocketHandshakeStreamCreateHelper create_helper(
254         &connect_delegate_, sub_protocols, &stream_request_);
255 
256     switch (GetParam()) {
257       case BASIC_HANDSHAKE_STREAM:
258         EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
259         break;
260 
261       case HTTP2_HANDSHAKE_STREAM:
262         EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
263         break;
264 
265       case HTTP3_HANDSHAKE_STREAM:
266         EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
267         break;
268 
269       default:
270         NOTREACHED();
271     }
272 
273     EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
274 
275     HttpRequestInfo request_info;
276     request_info.url = url;
277     request_info.method = "GET";
278     request_info.load_flags = LOAD_DISABLE_CACHE;
279     request_info.traffic_annotation =
280         MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
281 
282     auto headers = WebSocketCommonTestHeaders();
283 
284     switch (GetParam()) {
285       case BASIC_HANDSHAKE_STREAM: {
286         std::unique_ptr<ClientSocketHandle> socket_handle =
287             socket_handle_factory_.CreateClientSocketHandle(
288                 WebSocketStandardRequest(kPath, "www.example.org",
289                                          url::Origin::Create(GURL(kOrigin)),
290                                          /*send_additional_request_headers=*/{},
291                                          extra_request_headers),
292                 WebSocketStandardResponse(
293                     WebSocketExtraHeadersToString(extra_response_headers)));
294 
295         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
296             create_helper.CreateBasicStream(std::move(socket_handle), false,
297                                             &websocket_endpoint_lock_manager_);
298 
299         // If in future the implementation type returned by CreateBasicStream()
300         // changes, this static_cast will be wrong. However, in that case the
301         // test will fail and AddressSanitizer should identify the issue.
302         static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
303             ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
304 
305         handshake->RegisterRequest(&request_info);
306         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
307                                              CompletionOnceCallback());
308         EXPECT_THAT(rv, IsOk());
309 
310         HttpResponseInfo response;
311         TestCompletionCallback request_callback;
312         rv = handshake->SendRequest(headers, &response,
313                                     request_callback.callback());
314         EXPECT_THAT(rv, IsOk());
315 
316         TestCompletionCallback response_callback;
317         rv = handshake->ReadResponseHeaders(response_callback.callback());
318         EXPECT_THAT(rv, IsOk());
319         EXPECT_EQ(101, response.headers->response_code());
320         EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
321         EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
322         return handshake->Upgrade();
323       }
324       case HTTP2_HANDSHAKE_STREAM: {
325         SpdyTestUtil spdy_util;
326         spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
327             kPath, "www.example.org", kOrigin, extra_request_headers);
328         spdy::SpdySerializedFrame request_headers(
329             spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
330                                            DEFAULT_PRIORITY, false));
331         MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
332 
333         spdy::Http2HeaderBlock response_header_block =
334             WebSocketHttp2Response(extra_response_headers);
335         spdy::SpdySerializedFrame response_headers(
336             spdy_util.ConstructSpdyResponseHeaders(
337                 1, std::move(response_header_block), false));
338         MockRead reads[] = {CreateMockRead(response_headers, 1),
339                             MockRead(ASYNC, 0, 2)};
340 
341         SequencedSocketData data(reads, writes);
342 
343         SSLSocketDataProvider ssl(ASYNC, OK);
344         ssl.ssl_info.cert =
345             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
346 
347         SpdySessionDependencies session_deps;
348         session_deps.socket_factory->AddSocketDataProvider(&data);
349         session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
350 
351         std::unique_ptr<HttpNetworkSession> http_network_session =
352             SpdySessionDependencies::SpdyCreateSession(&session_deps);
353         const SpdySessionKey key(
354             HostPortPair::FromURL(url), PRIVACY_MODE_DISABLED,
355             ProxyChain::Direct(), SessionUsage::kDestination, SocketTag(),
356             NetworkAnonymizationKey(), SecureDnsPolicy::kAllow,
357             /*disable_cert_verification_network_fetches=*/false);
358         base::WeakPtr<SpdySession> spdy_session =
359             CreateSpdySession(http_network_session.get(), key, net_log);
360         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
361             create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
362 
363         handshake->RegisterRequest(&request_info);
364         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
365                                              NetLogWithSource(),
366                                              CompletionOnceCallback());
367         EXPECT_THAT(rv, IsOk());
368 
369         HttpResponseInfo response;
370         TestCompletionCallback request_callback;
371         rv = handshake->SendRequest(headers, &response,
372                                     request_callback.callback());
373         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
374         rv = request_callback.WaitForResult();
375         EXPECT_THAT(rv, IsOk());
376 
377         TestCompletionCallback response_callback;
378         rv = handshake->ReadResponseHeaders(response_callback.callback());
379         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
380         rv = response_callback.WaitForResult();
381         EXPECT_THAT(rv, IsOk());
382 
383         EXPECT_EQ(200, response.headers->response_code());
384         return handshake->Upgrade();
385       }
386       case HTTP3_HANDSHAKE_STREAM: {
387         const quic::QuicStreamId client_data_stream_id(
388             quic::QuicUtils::GetFirstBidirectionalStreamId(
389                 quic_version_.transport_version, quic::Perspective::IS_CLIENT));
390         quic::QuicCryptoClientConfig crypto_config(
391             quic::test::crypto_test_utils::ProofVerifierForTesting());
392 
393         const quic::QuicConnectionId connection_id(
394             quic::test::TestConnectionId(2));
395         test::QuicTestPacketMaker client_maker(
396             quic_version_, connection_id, &clock_, "mail.example.org",
397             quic::Perspective::IS_CLIENT,
398             /*client_headers_include_h2_stream_dependency_=*/false);
399         test::QuicTestPacketMaker server_maker(
400             quic_version_, connection_id, &clock_, "mail.example.org",
401             quic::Perspective::IS_SERVER,
402             /*client_headers_include_h2_stream_dependency_=*/false);
403         IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
404         quic::test::MockConnectionIdGenerator connection_id_generator;
405 
406         testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
407         ProofVerifyDetailsChromium verify_details;
408         MockCryptoClientStreamFactory crypto_client_stream_factory;
409         TransportSecurityState transport_security_state;
410         SSLConfigServiceDefaults ssl_config_service;
411 
412         FLAGS_quic_enable_http3_grease_randomness = false;
413         clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
414         quic::QuicEnableVersion(quic_version_);
415         quic::test::MockRandom random_generator{0};
416 
417         spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
418             kPath, "www.example.org", kOrigin, extra_request_headers);
419 
420         int packet_number = 1;
421         mock_quic_data_.AddWrite(
422             SYNCHRONOUS,
423             client_maker.MakeInitialSettingsPacket(packet_number++));
424 
425         mock_quic_data_.AddWrite(
426             ASYNC,
427             client_maker.MakeRequestHeadersPacket(
428                 packet_number++, client_data_stream_id,
429                 /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
430                 std::move(request_header_block), nullptr));
431 
432         spdy::Http2HeaderBlock response_header_block =
433             WebSocketHttp2Response(extra_response_headers);
434 
435         mock_quic_data_.AddRead(
436             ASYNC, server_maker.MakeResponseHeadersPacket(
437                        /*packet_number=*/1, client_data_stream_id,
438                        /*fin=*/false, std::move(response_header_block),
439                        /*spdy_headers_frame_length=*/nullptr));
440 
441         mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
442 
443         mock_quic_data_.AddWrite(SYNCHRONOUS,
444                                  client_maker.MakeAckAndRstPacket(
445                                      packet_number++, client_data_stream_id,
446                                      quic::QUIC_STREAM_CANCELLED, 1, 0,
447                                      /*include_stop_sending_if_v99=*/true));
448         auto socket = std::make_unique<MockUDPClientSocket>(
449             mock_quic_data_.InitializeAndGetSequencedSocketData(),
450             NetLog::Get());
451         socket->Connect(peer_addr);
452 
453         scoped_refptr<test::TestTaskRunner> runner =
454             base::MakeRefCounted<test::TestTaskRunner>(&clock_);
455         auto helper = std::make_unique<QuicChromiumConnectionHelper>(
456             &clock_, &random_generator);
457         auto alarm_factory =
458             std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
459         // Ownership of 'writer' is passed to 'QuicConnection'.
460         QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
461             socket.get(),
462             base::SingleThreadTaskRunner::GetCurrentDefault().get());
463         quic::QuicConnection* connection = new quic::QuicConnection(
464             connection_id, quic::QuicSocketAddress(),
465             net::ToQuicSocketAddress(peer_addr), helper.get(),
466             alarm_factory.get(), writer, true /* owns_writer */,
467             quic::Perspective::IS_CLIENT,
468             quic::test::SupportedVersions(quic_version_),
469             connection_id_generator);
470         connection->set_visitor(&visitor);
471 
472         // Load a certificate that is valid for *.example.org
473         scoped_refptr<X509Certificate> test_cert(
474             ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
475         EXPECT_TRUE(test_cert.get());
476 
477         verify_details.cert_verify_result.verified_cert = test_cert;
478         verify_details.cert_verify_result.is_issued_by_known_root = true;
479         crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
480 
481         base::TimeTicks dns_end = base::TimeTicks::Now();
482         base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
483 
484         session_ = std::make_unique<QuicChromiumClientSession>(
485             connection, std::move(socket),
486             /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
487             &transport_security_state, &ssl_config_service,
488             /*server_info=*/nullptr,
489             QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
490                            ProxyChain::Direct(), SessionUsage::kDestination,
491                            SocketTag(), NetworkAnonymizationKey(),
492                            SecureDnsPolicy::kAllow,
493                            /*require_dns_https_alpn=*/false),
494             /*require_confirmation=*/false,
495             /*migrate_session_early_v2=*/false,
496             /*migrate_session_on_network_change_v2=*/false,
497             /*default_network=*/handles::kInvalidNetworkHandle,
498             quic::QuicTime::Delta::FromMilliseconds(
499                 kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
500             /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
501             kDefaultIdleSessionMigrationPeriod,
502             /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
503             kMaxMigrationsToNonDefaultNetworkOnWriteError,
504             kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
505             kQuicYieldAfterPacketsRead,
506             quic::QuicTime::Delta::FromMilliseconds(
507                 kQuicYieldAfterDurationMilliseconds),
508             /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
509             std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
510             "CONNECTION_UNKNOWN", dns_start, dns_end,
511             base::DefaultTickClock::GetInstance(),
512             base::SingleThreadTaskRunner::GetCurrentDefault().get(),
513             /*socket_performance_watcher=*/nullptr,
514             ConnectionEndpointMetadata(),
515             NetLogWithSource::Make(NetLogSourceType::NONE));
516 
517         session_->Initialize();
518 
519         // Blackhole QPACK decoder stream instead of constructing mock writes.
520         session_->qpack_decoder()->set_qpack_stream_sender_delegate(
521             &noop_qpack_stream_sender_delegate_);
522         TestCompletionCallback callback;
523         EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
524         EXPECT_TRUE(session_->OneRttKeysAvailable());
525         std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
526             session_->CreateHandle(
527                 url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
528 
529         std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
530             create_helper.CreateHttp3Stream(std::move(session_handle),
531                                             {} /* dns_aliases */);
532 
533         handshake->RegisterRequest(&request_info);
534         int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
535                                              CompletionOnceCallback());
536         EXPECT_THAT(rv, IsOk());
537 
538         HttpResponseInfo response;
539         TestCompletionCallback request_callback;
540         rv = handshake->SendRequest(headers, &response,
541                                     request_callback.callback());
542         EXPECT_THAT(rv, IsOk());
543 
544         session_->StartReading();
545 
546         TestCompletionCallback response_callback;
547         rv = handshake->ReadResponseHeaders(response_callback.callback());
548         EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
549         rv = response_callback.WaitForResult();
550         EXPECT_THAT(rv, IsOk());
551 
552         EXPECT_EQ(200, response.headers->response_code());
553 
554         return handshake->Upgrade();
555       }
556       default:
557         NOTREACHED();
558         return nullptr;
559     }
560   }
561 
562  private:
563   MockClientSocketHandleFactory socket_handle_factory_;
564   TestConnectDelegate connect_delegate_;
565   StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
566   WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
567 
568   // For HTTP3_HANDSHAKE_STREAM
569   quic::ParsedQuicVersion quic_version_;
570   quic::MockClock clock_;
571   std::unique_ptr<QuicChromiumClientSession> session_;
572   test::MockQuicData mock_quic_data_;
573   quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
574 };
575 
576 INSTANTIATE_TEST_SUITE_P(All,
577                          WebSocketHandshakeStreamCreateHelperTest,
578                          Values(BASIC_HANDSHAKE_STREAM,
579                                 HTTP2_HANDSHAKE_STREAM,
580                                 HTTP3_HANDSHAKE_STREAM));
581 
582 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)583 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
584   std::unique_ptr<WebSocketStream> stream =
585       CreateAndInitializeStream({}, {}, {});
586   EXPECT_EQ("", stream->GetExtensions());
587   EXPECT_EQ("", stream->GetSubProtocol());
588 }
589 
590 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)591 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
592   std::vector<std::string> sub_protocols;
593   sub_protocols.push_back("chat");
594   sub_protocols.push_back("superchat");
595   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
596       sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
597       {{"Sec-WebSocket-Protocol", "superchat"}});
598   EXPECT_EQ("superchat", stream->GetSubProtocol());
599 }
600 
601 // Verify that extension name is available. Bad extension names are tested in
602 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)603 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
604   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
605       {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
606   EXPECT_EQ("permessage-deflate", stream->GetExtensions());
607 }
608 
609 // Verify that extension parameters are available. Bad parameters are tested in
610 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)611 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
612   std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
613       {}, {},
614       {{"Sec-WebSocket-Extensions",
615         "permessage-deflate;"
616         " client_max_window_bits=14; server_max_window_bits=14;"
617         " server_no_context_takeover; client_no_context_takeover"}});
618 
619   EXPECT_EQ(
620       "permessage-deflate;"
621       " client_max_window_bits=14; server_max_window_bits=14;"
622       " server_no_context_takeover; client_no_context_takeover",
623       stream->GetExtensions());
624 }
625 
626 }  // namespace
627 
628 }  // namespace net
629