1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_create_helper.h"
6
7 #include <optional>
8 #include <string>
9 #include <utility>
10 #include <vector>
11
12 #include "base/containers/span.h"
13 #include "base/functional/callback.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/notreached.h"
16 #include "base/strings/string_piece.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "base/time/default_tick_clock.h"
19 #include "base/time/time.h"
20 #include "net/base/auth.h"
21 #include "net/base/completion_once_callback.h"
22 #include "net/base/connection_endpoint_metadata.h"
23 #include "net/base/host_port_pair.h"
24 #include "net/base/ip_address.h"
25 #include "net/base/ip_endpoint.h"
26 #include "net/base/load_flags.h"
27 #include "net/base/net_errors.h"
28 #include "net/base/network_anonymization_key.h"
29 #include "net/base/network_handle.h"
30 #include "net/base/privacy_mode.h"
31 #include "net/base/proxy_chain.h"
32 #include "net/base/proxy_server.h"
33 #include "net/base/request_priority.h"
34 #include "net/base/session_usage.h"
35 #include "net/base/test_completion_callback.h"
36 #include "net/cert/cert_verify_result.h"
37 #include "net/dns/public/host_resolver_results.h"
38 #include "net/dns/public/secure_dns_policy.h"
39 #include "net/http/http_request_info.h"
40 #include "net/http/http_response_headers.h"
41 #include "net/http/http_response_info.h"
42 #include "net/http/transport_security_state.h"
43 #include "net/log/net_log.h"
44 #include "net/log/net_log_with_source.h"
45 #include "net/quic/address_utils.h"
46 #include "net/quic/crypto/proof_verifier_chromium.h"
47 #include "net/quic/mock_crypto_client_stream_factory.h"
48 #include "net/quic/mock_quic_data.h"
49 #include "net/quic/quic_chromium_alarm_factory.h"
50 #include "net/quic/quic_chromium_connection_helper.h"
51 #include "net/quic/quic_chromium_packet_reader.h"
52 #include "net/quic/quic_chromium_packet_writer.h"
53 #include "net/quic/quic_context.h"
54 #include "net/quic/quic_http_utils.h"
55 #include "net/quic/quic_server_info.h"
56 #include "net/quic/quic_session_key.h"
57 #include "net/quic/quic_test_packet_maker.h"
58 #include "net/quic/test_quic_crypto_client_config_handle.h"
59 #include "net/quic/test_task_runner.h"
60 #include "net/socket/client_socket_handle.h"
61 #include "net/socket/client_socket_pool.h"
62 #include "net/socket/connect_job.h"
63 #include "net/socket/socket_tag.h"
64 #include "net/socket/socket_test_util.h"
65 #include "net/socket/websocket_endpoint_lock_manager.h"
66 #include "net/spdy/spdy_session_key.h"
67 #include "net/spdy/spdy_test_util_common.h"
68 #include "net/ssl/ssl_config_service_defaults.h"
69 #include "net/ssl/ssl_info.h"
70 #include "net/test/cert_test_util.h"
71 #include "net/test/gtest_util.h"
72 #include "net/test/test_data_directory.h"
73 #include "net/test/test_with_task_environment.h"
74 #include "net/third_party/quiche/src/quiche/common/platform/api/quiche_flags.h"
75 #include "net/third_party/quiche/src/quiche/quic/core/crypto/quic_crypto_client_config.h"
76 #include "net/third_party/quiche/src/quiche/quic/core/qpack/qpack_decoder.h"
77 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection.h"
78 #include "net/third_party/quiche/src/quiche/quic/core/quic_connection_id.h"
79 #include "net/third_party/quiche/src/quiche/quic/core/quic_error_codes.h"
80 #include "net/third_party/quiche/src/quiche/quic/core/quic_packets.h"
81 #include "net/third_party/quiche/src/quiche/quic/core/quic_time.h"
82 #include "net/third_party/quiche/src/quiche/quic/core/quic_types.h"
83 #include "net/third_party/quiche/src/quiche/quic/core/quic_utils.h"
84 #include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
85 #include "net/third_party/quiche/src/quiche/quic/platform/api/quic_socket_address.h"
86 #include "net/third_party/quiche/src/quiche/quic/test_tools/crypto_test_utils.h"
87 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_clock.h"
88 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_connection_id_generator.h"
89 #include "net/third_party/quiche/src/quiche/quic/test_tools/mock_random.h"
90 #include "net/third_party/quiche/src/quiche/quic/test_tools/qpack/qpack_test_utils.h"
91 #include "net/third_party/quiche/src/quiche/quic/test_tools/quic_test_utils.h"
92 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
93 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
94 #include "net/traffic_annotation/network_traffic_annotation.h"
95 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
96 #include "net/websockets/websocket_basic_handshake_stream.h"
97 #include "net/websockets/websocket_event_interface.h"
98 #include "net/websockets/websocket_stream.h"
99 #include "net/websockets/websocket_test_util.h"
100 #include "testing/gmock/include/gmock/gmock.h"
101 #include "testing/gtest/include/gtest/gtest.h"
102 #include "url/gurl.h"
103 #include "url/origin.h"
104 #include "url/scheme_host_port.h"
105 #include "url/url_constants.h"
106
namespace net {
// Forward declarations for types this test refers to whose defining headers
// are not included directly by this file.
struct WebSocketHandshakeRequestInfo;
struct WebSocketHandshakeResponseInfo;

class HttpNetworkSession;
class URLRequest;
class WebSocketHttp2HandshakeStream;
class WebSocketHttp3HandshakeStream;
class X509Certificate;
}  // namespace net
116
117 using ::net::test::IsError;
118 using ::net::test::IsOk;
119 using ::testing::_;
120 using ::testing::StrictMock;
121 using ::testing::TestWithParam;
122 using ::testing::Values;
123
124 namespace net {
125 namespace {
126
// Selects which kind of handshake stream a parameterized test instance of
// WebSocketHandshakeStreamCreateHelperTest creates.
enum HandshakeStreamType {
  // Created via WebSocketHandshakeStreamCreateHelper::CreateBasicStream().
  BASIC_HANDSHAKE_STREAM = 0,
  // Created via WebSocketHandshakeStreamCreateHelper::CreateHttp2Stream().
  HTTP2_HANDSHAKE_STREAM = 1,
  // Created via WebSocketHandshakeStreamCreateHelper::CreateHttp3Stream().
  HTTP3_HANDSHAKE_STREAM = 2,
};
132
133 // This class encapsulates the details of creating a mock ClientSocketHandle.
134 class MockClientSocketHandleFactory {
135 public:
MockClientSocketHandleFactory()136 MockClientSocketHandleFactory()
137 : common_connect_job_params_(
138 socket_factory_maker_.factory(),
139 /*host_resolver=*/nullptr,
140 /*http_auth_cache=*/nullptr,
141 /*http_auth_handler_factory=*/nullptr,
142 /*spdy_session_pool=*/nullptr,
143 /*quic_supported_versions=*/nullptr,
144 /*quic_session_pool=*/nullptr,
145 /*proxy_delegate=*/nullptr,
146 /*http_user_agent_settings=*/nullptr,
147 /*ssl_client_context=*/nullptr,
148 /*socket_performance_watcher_factory=*/nullptr,
149 /*network_quality_estimator=*/nullptr,
150 /*net_log=*/nullptr,
151 /*websocket_endpoint_lock_manager=*/nullptr,
152 /*http_server_properties=*/nullptr,
153 /*alpn_protos=*/nullptr,
154 /*application_settings=*/nullptr,
155 /*ignore_certificate_errors=*/nullptr,
156 /*early_data_enabled=*/nullptr),
157 pool_(1, 1, &common_connect_job_params_) {}
158
159 MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
160 MockClientSocketHandleFactory& operator=(
161 const MockClientSocketHandleFactory&) = delete;
162
163 // The created socket expects |expect_written| to be written to the socket,
164 // and will respond with |return_to_read|. The test will fail if the expected
165 // text is not written, or if all the bytes are not read.
CreateClientSocketHandle(const std::string & expect_written,const std::string & return_to_read)166 std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
167 const std::string& expect_written,
168 const std::string& return_to_read) {
169 socket_factory_maker_.SetExpectations(expect_written, return_to_read);
170 auto socket_handle = std::make_unique<ClientSocketHandle>();
171 socket_handle->Init(
172 ClientSocketPool::GroupId(
173 url::SchemeHostPort(url::kHttpScheme, "a", 80),
174 PrivacyMode::PRIVACY_MODE_DISABLED, NetworkAnonymizationKey(),
175 SecureDnsPolicy::kAllow, /*disable_cert_network_fetches=*/false),
176 scoped_refptr<ClientSocketPool::SocketParams>(),
177 std::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
178 ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
179 ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
180 return socket_handle;
181 }
182
183 private:
184 WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
185 const CommonConnectJobParams common_connect_job_params_;
186 MockTransportClientSocketPool pool_;
187 };
188
189 class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
190 public:
191 ~TestConnectDelegate() override = default;
192
OnCreateRequest(URLRequest * request)193 void OnCreateRequest(URLRequest* request) override {}
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)194 void OnURLRequestConnected(URLRequest* request,
195 const TransportInfo& info) override {}
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)196 void OnSuccess(
197 std::unique_ptr<WebSocketStream> stream,
198 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
OnFailure(const std::string & failure_message,int net_error,std::optional<int> response_code)199 void OnFailure(const std::string& failure_message,
200 int net_error,
201 std::optional<int> response_code) override {}
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)202 void OnStartOpeningHandshake(
203 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)204 void OnSSLCertificateError(
205 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
206 ssl_error_callbacks,
207 int net_error,
208 const SSLInfo& ssl_info,
209 bool fatal) override {}
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)210 int OnAuthRequired(const AuthChallengeInfo& auth_info,
211 scoped_refptr<HttpResponseHeaders> response_headers,
212 const IPEndPoint& host_port_pair,
213 base::OnceCallback<void(const AuthCredentials*)> callback,
214 std::optional<AuthCredentials>* credentials) override {
215 *credentials = std::nullopt;
216 return OK;
217 }
218 };
219
220 class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
221 public:
222 ~MockWebSocketStreamRequestAPI() override = default;
223
224 MOCK_METHOD1(OnBasicHandshakeStreamCreated,
225 void(WebSocketBasicHandshakeStream* handshake_stream));
226 MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
227 void(WebSocketHttp2HandshakeStream* handshake_stream));
228 MOCK_METHOD1(OnHttp3HandshakeStreamCreated,
229 void(WebSocketHttp3HandshakeStream* handshake_stream));
230 MOCK_METHOD3(OnFailure,
231 void(const std::string& message,
232 int net_error,
233 std::optional<int> response_code));
234 };
235
236 class WebSocketHandshakeStreamCreateHelperTest
237 : public TestWithParam<HandshakeStreamType>,
238 public WithTaskEnvironment {
239 protected:
WebSocketHandshakeStreamCreateHelperTest()240 WebSocketHandshakeStreamCreateHelperTest()
241 : quic_version_(quic::HandshakeProtocol::PROTOCOL_TLS1_3,
242 quic::QuicTransportVersion::QUIC_VERSION_IETF_RFC_V1),
243 mock_quic_data_(quic_version_) {}
CreateAndInitializeStream(const std::vector<std::string> & sub_protocols,const WebSocketExtraHeaders & extra_request_headers,const WebSocketExtraHeaders & extra_response_headers)244 std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
245 const std::vector<std::string>& sub_protocols,
246 const WebSocketExtraHeaders& extra_request_headers,
247 const WebSocketExtraHeaders& extra_response_headers) {
248 constexpr char kPath[] = "/";
249 constexpr char kOrigin[] = "http://origin.example.org";
250 const GURL url("wss://www.example.org/");
251 NetLogWithSource net_log;
252
253 WebSocketHandshakeStreamCreateHelper create_helper(
254 &connect_delegate_, sub_protocols, &stream_request_);
255
256 switch (GetParam()) {
257 case BASIC_HANDSHAKE_STREAM:
258 EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
259 break;
260
261 case HTTP2_HANDSHAKE_STREAM:
262 EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
263 break;
264
265 case HTTP3_HANDSHAKE_STREAM:
266 EXPECT_CALL(stream_request_, OnHttp3HandshakeStreamCreated(_)).Times(1);
267 break;
268
269 default:
270 NOTREACHED();
271 }
272
273 EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
274
275 HttpRequestInfo request_info;
276 request_info.url = url;
277 request_info.method = "GET";
278 request_info.load_flags = LOAD_DISABLE_CACHE;
279 request_info.traffic_annotation =
280 MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
281
282 auto headers = WebSocketCommonTestHeaders();
283
284 switch (GetParam()) {
285 case BASIC_HANDSHAKE_STREAM: {
286 std::unique_ptr<ClientSocketHandle> socket_handle =
287 socket_handle_factory_.CreateClientSocketHandle(
288 WebSocketStandardRequest(kPath, "www.example.org",
289 url::Origin::Create(GURL(kOrigin)),
290 /*send_additional_request_headers=*/{},
291 extra_request_headers),
292 WebSocketStandardResponse(
293 WebSocketExtraHeadersToString(extra_response_headers)));
294
295 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
296 create_helper.CreateBasicStream(std::move(socket_handle), false,
297 &websocket_endpoint_lock_manager_);
298
299 // If in future the implementation type returned by CreateBasicStream()
300 // changes, this static_cast will be wrong. However, in that case the
301 // test will fail and AddressSanitizer should identify the issue.
302 static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
303 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
304
305 handshake->RegisterRequest(&request_info);
306 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
307 CompletionOnceCallback());
308 EXPECT_THAT(rv, IsOk());
309
310 HttpResponseInfo response;
311 TestCompletionCallback request_callback;
312 rv = handshake->SendRequest(headers, &response,
313 request_callback.callback());
314 EXPECT_THAT(rv, IsOk());
315
316 TestCompletionCallback response_callback;
317 rv = handshake->ReadResponseHeaders(response_callback.callback());
318 EXPECT_THAT(rv, IsOk());
319 EXPECT_EQ(101, response.headers->response_code());
320 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
321 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
322 return handshake->Upgrade();
323 }
324 case HTTP2_HANDSHAKE_STREAM: {
325 SpdyTestUtil spdy_util;
326 spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
327 kPath, "www.example.org", kOrigin, extra_request_headers);
328 spdy::SpdySerializedFrame request_headers(
329 spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
330 DEFAULT_PRIORITY, false));
331 MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
332
333 spdy::Http2HeaderBlock response_header_block =
334 WebSocketHttp2Response(extra_response_headers);
335 spdy::SpdySerializedFrame response_headers(
336 spdy_util.ConstructSpdyResponseHeaders(
337 1, std::move(response_header_block), false));
338 MockRead reads[] = {CreateMockRead(response_headers, 1),
339 MockRead(ASYNC, 0, 2)};
340
341 SequencedSocketData data(reads, writes);
342
343 SSLSocketDataProvider ssl(ASYNC, OK);
344 ssl.ssl_info.cert =
345 ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
346
347 SpdySessionDependencies session_deps;
348 session_deps.socket_factory->AddSocketDataProvider(&data);
349 session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
350
351 std::unique_ptr<HttpNetworkSession> http_network_session =
352 SpdySessionDependencies::SpdyCreateSession(&session_deps);
353 const SpdySessionKey key(
354 HostPortPair::FromURL(url), PRIVACY_MODE_DISABLED,
355 ProxyChain::Direct(), SessionUsage::kDestination, SocketTag(),
356 NetworkAnonymizationKey(), SecureDnsPolicy::kAllow,
357 /*disable_cert_verification_network_fetches=*/false);
358 base::WeakPtr<SpdySession> spdy_session =
359 CreateSpdySession(http_network_session.get(), key, net_log);
360 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
361 create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
362
363 handshake->RegisterRequest(&request_info);
364 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
365 NetLogWithSource(),
366 CompletionOnceCallback());
367 EXPECT_THAT(rv, IsOk());
368
369 HttpResponseInfo response;
370 TestCompletionCallback request_callback;
371 rv = handshake->SendRequest(headers, &response,
372 request_callback.callback());
373 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
374 rv = request_callback.WaitForResult();
375 EXPECT_THAT(rv, IsOk());
376
377 TestCompletionCallback response_callback;
378 rv = handshake->ReadResponseHeaders(response_callback.callback());
379 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
380 rv = response_callback.WaitForResult();
381 EXPECT_THAT(rv, IsOk());
382
383 EXPECT_EQ(200, response.headers->response_code());
384 return handshake->Upgrade();
385 }
386 case HTTP3_HANDSHAKE_STREAM: {
387 const quic::QuicStreamId client_data_stream_id(
388 quic::QuicUtils::GetFirstBidirectionalStreamId(
389 quic_version_.transport_version, quic::Perspective::IS_CLIENT));
390 quic::QuicCryptoClientConfig crypto_config(
391 quic::test::crypto_test_utils::ProofVerifierForTesting());
392
393 const quic::QuicConnectionId connection_id(
394 quic::test::TestConnectionId(2));
395 test::QuicTestPacketMaker client_maker(
396 quic_version_, connection_id, &clock_, "mail.example.org",
397 quic::Perspective::IS_CLIENT,
398 /*client_headers_include_h2_stream_dependency_=*/false);
399 test::QuicTestPacketMaker server_maker(
400 quic_version_, connection_id, &clock_, "mail.example.org",
401 quic::Perspective::IS_SERVER,
402 /*client_headers_include_h2_stream_dependency_=*/false);
403 IPEndPoint peer_addr(IPAddress(192, 0, 2, 23), 443);
404 quic::test::MockConnectionIdGenerator connection_id_generator;
405
406 testing::StrictMock<quic::test::MockQuicConnectionVisitor> visitor;
407 ProofVerifyDetailsChromium verify_details;
408 MockCryptoClientStreamFactory crypto_client_stream_factory;
409 TransportSecurityState transport_security_state;
410 SSLConfigServiceDefaults ssl_config_service;
411
412 FLAGS_quic_enable_http3_grease_randomness = false;
413 clock_.AdvanceTime(quic::QuicTime::Delta::FromMilliseconds(20));
414 quic::QuicEnableVersion(quic_version_);
415 quic::test::MockRandom random_generator{0};
416
417 spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
418 kPath, "www.example.org", kOrigin, extra_request_headers);
419
420 int packet_number = 1;
421 mock_quic_data_.AddWrite(
422 SYNCHRONOUS,
423 client_maker.MakeInitialSettingsPacket(packet_number++));
424
425 mock_quic_data_.AddWrite(
426 ASYNC,
427 client_maker.MakeRequestHeadersPacket(
428 packet_number++, client_data_stream_id,
429 /*fin=*/false, ConvertRequestPriorityToQuicPriority(LOWEST),
430 std::move(request_header_block), nullptr));
431
432 spdy::Http2HeaderBlock response_header_block =
433 WebSocketHttp2Response(extra_response_headers);
434
435 mock_quic_data_.AddRead(
436 ASYNC, server_maker.MakeResponseHeadersPacket(
437 /*packet_number=*/1, client_data_stream_id,
438 /*fin=*/false, std::move(response_header_block),
439 /*spdy_headers_frame_length=*/nullptr));
440
441 mock_quic_data_.AddRead(SYNCHRONOUS, ERR_IO_PENDING);
442
443 mock_quic_data_.AddWrite(SYNCHRONOUS,
444 client_maker.MakeAckAndRstPacket(
445 packet_number++, client_data_stream_id,
446 quic::QUIC_STREAM_CANCELLED, 1, 0,
447 /*include_stop_sending_if_v99=*/true));
448 auto socket = std::make_unique<MockUDPClientSocket>(
449 mock_quic_data_.InitializeAndGetSequencedSocketData(),
450 NetLog::Get());
451 socket->Connect(peer_addr);
452
453 scoped_refptr<test::TestTaskRunner> runner =
454 base::MakeRefCounted<test::TestTaskRunner>(&clock_);
455 auto helper = std::make_unique<QuicChromiumConnectionHelper>(
456 &clock_, &random_generator);
457 auto alarm_factory =
458 std::make_unique<QuicChromiumAlarmFactory>(runner.get(), &clock_);
459 // Ownership of 'writer' is passed to 'QuicConnection'.
460 QuicChromiumPacketWriter* writer = new QuicChromiumPacketWriter(
461 socket.get(),
462 base::SingleThreadTaskRunner::GetCurrentDefault().get());
463 quic::QuicConnection* connection = new quic::QuicConnection(
464 connection_id, quic::QuicSocketAddress(),
465 net::ToQuicSocketAddress(peer_addr), helper.get(),
466 alarm_factory.get(), writer, true /* owns_writer */,
467 quic::Perspective::IS_CLIENT,
468 quic::test::SupportedVersions(quic_version_),
469 connection_id_generator);
470 connection->set_visitor(&visitor);
471
472 // Load a certificate that is valid for *.example.org
473 scoped_refptr<X509Certificate> test_cert(
474 ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem"));
475 EXPECT_TRUE(test_cert.get());
476
477 verify_details.cert_verify_result.verified_cert = test_cert;
478 verify_details.cert_verify_result.is_issued_by_known_root = true;
479 crypto_client_stream_factory.AddProofVerifyDetails(&verify_details);
480
481 base::TimeTicks dns_end = base::TimeTicks::Now();
482 base::TimeTicks dns_start = dns_end - base::Milliseconds(1);
483
484 session_ = std::make_unique<QuicChromiumClientSession>(
485 connection, std::move(socket),
486 /*stream_factory=*/nullptr, &crypto_client_stream_factory, &clock_,
487 &transport_security_state, &ssl_config_service,
488 /*server_info=*/nullptr,
489 QuicSessionKey("mail.example.org", 80, PRIVACY_MODE_DISABLED,
490 ProxyChain::Direct(), SessionUsage::kDestination,
491 SocketTag(), NetworkAnonymizationKey(),
492 SecureDnsPolicy::kAllow,
493 /*require_dns_https_alpn=*/false),
494 /*require_confirmation=*/false,
495 /*migrate_session_early_v2=*/false,
496 /*migrate_session_on_network_change_v2=*/false,
497 /*default_network=*/handles::kInvalidNetworkHandle,
498 quic::QuicTime::Delta::FromMilliseconds(
499 kDefaultRetransmittableOnWireTimeout.InMilliseconds()),
500 /*migrate_idle_session=*/true, /*allow_port_migration=*/false,
501 kDefaultIdleSessionMigrationPeriod,
502 /*multi_port_probing_interval=*/0, kMaxTimeOnNonDefaultNetwork,
503 kMaxMigrationsToNonDefaultNetworkOnWriteError,
504 kMaxMigrationsToNonDefaultNetworkOnPathDegrading,
505 kQuicYieldAfterPacketsRead,
506 quic::QuicTime::Delta::FromMilliseconds(
507 kQuicYieldAfterDurationMilliseconds),
508 /*cert_verify_flags=*/0, quic::test::DefaultQuicConfig(),
509 std::make_unique<TestQuicCryptoClientConfigHandle>(&crypto_config),
510 "CONNECTION_UNKNOWN", dns_start, dns_end,
511 base::DefaultTickClock::GetInstance(),
512 base::SingleThreadTaskRunner::GetCurrentDefault().get(),
513 /*socket_performance_watcher=*/nullptr,
514 ConnectionEndpointMetadata(),
515 NetLogWithSource::Make(NetLogSourceType::NONE));
516
517 session_->Initialize();
518
519 // Blackhole QPACK decoder stream instead of constructing mock writes.
520 session_->qpack_decoder()->set_qpack_stream_sender_delegate(
521 &noop_qpack_stream_sender_delegate_);
522 TestCompletionCallback callback;
523 EXPECT_THAT(session_->CryptoConnect(callback.callback()), IsOk());
524 EXPECT_TRUE(session_->OneRttKeysAvailable());
525 std::unique_ptr<QuicChromiumClientSession::Handle> session_handle =
526 session_->CreateHandle(
527 url::SchemeHostPort(url::kHttpsScheme, "mail.example.org", 80));
528
529 std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
530 create_helper.CreateHttp3Stream(std::move(session_handle),
531 {} /* dns_aliases */);
532
533 handshake->RegisterRequest(&request_info);
534 int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
535 CompletionOnceCallback());
536 EXPECT_THAT(rv, IsOk());
537
538 HttpResponseInfo response;
539 TestCompletionCallback request_callback;
540 rv = handshake->SendRequest(headers, &response,
541 request_callback.callback());
542 EXPECT_THAT(rv, IsOk());
543
544 session_->StartReading();
545
546 TestCompletionCallback response_callback;
547 rv = handshake->ReadResponseHeaders(response_callback.callback());
548 EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
549 rv = response_callback.WaitForResult();
550 EXPECT_THAT(rv, IsOk());
551
552 EXPECT_EQ(200, response.headers->response_code());
553
554 return handshake->Upgrade();
555 }
556 default:
557 NOTREACHED();
558 return nullptr;
559 }
560 }
561
562 private:
563 MockClientSocketHandleFactory socket_handle_factory_;
564 TestConnectDelegate connect_delegate_;
565 StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
566 WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
567
568 // For HTTP3_HANDSHAKE_STREAM
569 quic::ParsedQuicVersion quic_version_;
570 quic::MockClock clock_;
571 std::unique_ptr<QuicChromiumClientSession> session_;
572 test::MockQuicData mock_quic_data_;
573 quic::test::NoopQpackStreamSenderDelegate noop_qpack_stream_sender_delegate_;
574 };
575
576 INSTANTIATE_TEST_SUITE_P(All,
577 WebSocketHandshakeStreamCreateHelperTest,
578 Values(BASIC_HANDSHAKE_STREAM,
579 HTTP2_HANDSHAKE_STREAM,
580 HTTP3_HANDSHAKE_STREAM));
581
582 // Confirm that the basic case works as expected.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,BasicStream)583 TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
584 std::unique_ptr<WebSocketStream> stream =
585 CreateAndInitializeStream({}, {}, {});
586 EXPECT_EQ("", stream->GetExtensions());
587 EXPECT_EQ("", stream->GetSubProtocol());
588 }
589
590 // Verify that the sub-protocols are passed through.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,SubProtocols)591 TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
592 std::vector<std::string> sub_protocols;
593 sub_protocols.push_back("chat");
594 sub_protocols.push_back("superchat");
595 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
596 sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
597 {{"Sec-WebSocket-Protocol", "superchat"}});
598 EXPECT_EQ("superchat", stream->GetSubProtocol());
599 }
600
601 // Verify that extension name is available. Bad extension names are tested in
602 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,Extensions)603 TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
604 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
605 {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
606 EXPECT_EQ("permessage-deflate", stream->GetExtensions());
607 }
608
609 // Verify that extension parameters are available. Bad parameters are tested in
610 // websocket_stream_test.cc.
TEST_P(WebSocketHandshakeStreamCreateHelperTest,ExtensionParameters)611 TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
612 std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
613 {}, {},
614 {{"Sec-WebSocket-Extensions",
615 "permessage-deflate;"
616 " client_max_window_bits=14; server_max_window_bits=14;"
617 " server_no_context_takeover; client_no_context_takeover"}});
618
619 EXPECT_EQ(
620 "permessage-deflate;"
621 " client_max_window_bits=14; server_max_window_bits=14;"
622 " server_no_context_takeover; client_no_context_takeover",
623 stream->GetExtensions());
624 }
625
626 } // namespace
627
628 } // namespace net
629