1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 // Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 // Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15
16 #include "net/socket/ssl_server_socket.h"
17
18 #include <stdint.h>
19 #include <stdlib.h>
20
21 #include <memory>
22 #include <string_view>
23 #include <utility>
24
25 #include "base/check.h"
26 #include "base/compiler_specific.h"
27 #include "base/containers/queue.h"
28 #include "base/files/file_path.h"
29 #include "base/files/file_util.h"
30 #include "base/functional/bind.h"
31 #include "base/functional/callback_helpers.h"
32 #include "base/location.h"
33 #include "base/memory/raw_ptr.h"
34 #include "base/memory/scoped_refptr.h"
35 #include "base/notreached.h"
36 #include "base/run_loop.h"
37 #include "base/task/single_thread_task_runner.h"
38 #include "base/test/bind.h"
39 #include "base/test/task_environment.h"
40 #include "build/build_config.h"
41 #include "crypto/rsa_private_key.h"
42 #include "net/base/address_list.h"
43 #include "net/base/completion_once_callback.h"
44 #include "net/base/host_port_pair.h"
45 #include "net/base/io_buffer.h"
46 #include "net/base/ip_address.h"
47 #include "net/base/ip_endpoint.h"
48 #include "net/base/net_errors.h"
49 #include "net/cert/cert_status_flags.h"
50 #include "net/cert/mock_cert_verifier.h"
51 #include "net/cert/mock_client_cert_verifier.h"
52 #include "net/cert/signed_certificate_timestamp_and_status.h"
53 #include "net/cert/x509_certificate.h"
54 #include "net/http/transport_security_state.h"
55 #include "net/log/net_log_with_source.h"
56 #include "net/socket/client_socket_factory.h"
57 #include "net/socket/socket_test_util.h"
58 #include "net/socket/ssl_client_socket.h"
59 #include "net/socket/stream_socket.h"
60 #include "net/ssl/openssl_private_key.h"
61 #include "net/ssl/ssl_cert_request_info.h"
62 #include "net/ssl/ssl_cipher_suite_names.h"
63 #include "net/ssl/ssl_client_session_cache.h"
64 #include "net/ssl/ssl_connection_status_flags.h"
65 #include "net/ssl/ssl_info.h"
66 #include "net/ssl/ssl_private_key.h"
67 #include "net/ssl/ssl_server_config.h"
68 #include "net/ssl/test_ssl_config_service.h"
69 #include "net/test/cert_test_util.h"
70 #include "net/test/gtest_util.h"
71 #include "net/test/test_data_directory.h"
72 #include "net/test/test_with_task_environment.h"
73 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
74 #include "testing/gmock/include/gmock/gmock.h"
75 #include "testing/gtest/include/gtest/gtest.h"
76 #include "testing/platform_test.h"
77 #include "third_party/boringssl/src/include/openssl/evp.h"
78 #include "third_party/boringssl/src/include/openssl/ssl.h"
79
80 using net::test::IsError;
81 using net::test::IsOk;
82
83 namespace net {
84
85 namespace {
86
87 // Client certificates are disabled on iOS.
88 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
89 const char kClientCertFileName[] = "client_1.pem";
90 const char kClientPrivateKeyFileName[] = "client_1.pk8";
91 const char kWrongClientCertFileName[] = "client_2.pem";
92 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
93 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
94
// TLS cipher-suite code points for the ECDHE-based suites: RC4, AES-CBC,
// AES-GCM and ChaCha20-Poly1305, each with ECDSA and RSA authentication.
constexpr uint16_t kEcdheCiphers[] = {
    0xc007,  // ECDHE_ECDSA_WITH_RC4_128_SHA
    0xc009,  // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xc00a,  // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xc011,  // ECDHE_RSA_WITH_RC4_128_SHA
    0xc013,  // ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xc014,  // ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xc02b,  // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xc02c,  // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xc02f,  // ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xc030,  // ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xcca8,  // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xcca9,  // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
109
110 class FakeDataChannel {
111 public:
112 FakeDataChannel() = default;
113
114 FakeDataChannel(const FakeDataChannel&) = delete;
115 FakeDataChannel& operator=(const FakeDataChannel&) = delete;
116
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)117 int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
118 DCHECK(read_callback_.is_null());
119 DCHECK(!read_buf_.get());
120 if (closed_)
121 return 0;
122 if (data_.empty()) {
123 read_callback_ = std::move(callback);
124 read_buf_ = buf;
125 read_buf_len_ = buf_len;
126 return ERR_IO_PENDING;
127 }
128 return PropagateData(buf, buf_len);
129 }
130
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)131 int Write(IOBuffer* buf,
132 int buf_len,
133 CompletionOnceCallback callback,
134 const NetworkTrafficAnnotationTag& traffic_annotation) {
135 DCHECK(write_callback_.is_null());
136 if (closed_) {
137 if (write_called_after_close_)
138 return ERR_CONNECTION_RESET;
139 write_called_after_close_ = true;
140 write_callback_ = std::move(callback);
141 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
142 FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
143 weak_factory_.GetWeakPtr()));
144 return ERR_IO_PENDING;
145 }
146 // This function returns synchronously, so make a copy of the buffer.
147 data_.push(base::MakeRefCounted<DrainableIOBuffer>(
148 base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
149 buf_len));
150 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
151 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
152 weak_factory_.GetWeakPtr()));
153 return buf_len;
154 }
155
156 // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
157 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
158 // after the FakeDataChannel is closed, the first Write() call completes
159 // asynchronously, which is necessary to reproduce bug 127822.
Close()160 void Close() {
161 closed_ = true;
162 if (!read_callback_.is_null()) {
163 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
164 FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
165 weak_factory_.GetWeakPtr()));
166 }
167 }
168
169 private:
DoReadCallback()170 void DoReadCallback() {
171 if (read_callback_.is_null())
172 return;
173
174 if (closed_) {
175 std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
176 return;
177 }
178
179 if (data_.empty())
180 return;
181
182 int copied = PropagateData(read_buf_, read_buf_len_);
183 read_buf_ = nullptr;
184 read_buf_len_ = 0;
185 std::move(read_callback_).Run(copied);
186 }
187
DoWriteCallback()188 void DoWriteCallback() {
189 if (write_callback_.is_null())
190 return;
191
192 std::move(write_callback_).Run(ERR_CONNECTION_RESET);
193 }
194
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)195 int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
196 scoped_refptr<DrainableIOBuffer> buf = data_.front();
197 int copied = std::min(buf->BytesRemaining(), read_buf_len);
198 memcpy(read_buf->data(), buf->data(), copied);
199 buf->DidConsume(copied);
200
201 if (!buf->BytesRemaining())
202 data_.pop();
203 return copied;
204 }
205
206 CompletionOnceCallback read_callback_;
207 scoped_refptr<IOBuffer> read_buf_;
208 int read_buf_len_ = 0;
209
210 CompletionOnceCallback write_callback_;
211
212 base::queue<scoped_refptr<DrainableIOBuffer>> data_;
213
214 // True if Close() has been called.
215 bool closed_ = false;
216
217 // Controls the completion of Write() after the FakeDataChannel is closed.
218 // After the FakeDataChannel is closed, the first Write() call completes
219 // asynchronously.
220 bool write_called_after_close_ = false;
221
222 base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
223 };
224
225 class FakeSocket : public StreamSocket {
226 public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)227 FakeSocket(FakeDataChannel* incoming_channel,
228 FakeDataChannel* outgoing_channel)
229 : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
230
231 FakeSocket(const FakeSocket&) = delete;
232 FakeSocket& operator=(const FakeSocket&) = delete;
233
234 ~FakeSocket() override = default;
235
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)236 int Read(IOBuffer* buf,
237 int buf_len,
238 CompletionOnceCallback callback) override {
239 // Read random number of bytes.
240 buf_len = rand() % buf_len + 1;
241 return incoming_->Read(buf, buf_len, std::move(callback));
242 }
243
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)244 int Write(IOBuffer* buf,
245 int buf_len,
246 CompletionOnceCallback callback,
247 const NetworkTrafficAnnotationTag& traffic_annotation) override {
248 // Write random number of bytes.
249 buf_len = rand() % buf_len + 1;
250 return outgoing_->Write(buf, buf_len, std::move(callback),
251 TRAFFIC_ANNOTATION_FOR_TESTS);
252 }
253
SetReceiveBufferSize(int32_t size)254 int SetReceiveBufferSize(int32_t size) override { return OK; }
255
SetSendBufferSize(int32_t size)256 int SetSendBufferSize(int32_t size) override { return OK; }
257
Connect(CompletionOnceCallback callback)258 int Connect(CompletionOnceCallback callback) override { return OK; }
259
Disconnect()260 void Disconnect() override {
261 incoming_->Close();
262 outgoing_->Close();
263 }
264
IsConnected() const265 bool IsConnected() const override { return true; }
266
IsConnectedAndIdle() const267 bool IsConnectedAndIdle() const override { return true; }
268
GetPeerAddress(IPEndPoint * address) const269 int GetPeerAddress(IPEndPoint* address) const override {
270 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
271 return OK;
272 }
273
GetLocalAddress(IPEndPoint * address) const274 int GetLocalAddress(IPEndPoint* address) const override {
275 *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
276 return OK;
277 }
278
NetLog() const279 const NetLogWithSource& NetLog() const override { return net_log_; }
280
WasEverUsed() const281 bool WasEverUsed() const override { return true; }
282
GetNegotiatedProtocol() const283 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
284
GetSSLInfo(SSLInfo * ssl_info)285 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
286
GetTotalReceivedBytes() const287 int64_t GetTotalReceivedBytes() const override {
288 NOTIMPLEMENTED();
289 return 0;
290 }
291
ApplySocketTag(const SocketTag & tag)292 void ApplySocketTag(const SocketTag& tag) override {}
293
294 private:
295 NetLogWithSource net_log_;
296 raw_ptr<FakeDataChannel> incoming_;
297 raw_ptr<FakeDataChannel> outgoing_;
298 };
299
300 } // namespace
301
302 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest, DataTransfer) {
  base::test::TaskEnvironment task_environment;

  // Wire two FakeSockets back to back: what |server| writes is read by
  // |client|, and vice versa.
  FakeDataChannel server_to_client;
  FakeDataChannel client_to_server;
  FakeSocket client(&server_to_client, &client_to_server);
  FakeSocket server(&client_to_server, &server_to_client);

  const char kTestData[] = "testing123";
  const int kTestDataSize = static_cast<int>(strlen(kTestData));
  const int kReadBufSize = 1024;
  auto write_buf = base::MakeRefCounted<StringIOBuffer>(kTestData);
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);

  // FakeSocket truncates every transfer to a random non-zero length, so a
  // write may send only a prefix of the payload and a read may return only a
  // prefix of what was written.
  auto check_write = [&](int bytes_written) {
    EXPECT_GT(bytes_written, 0);
    EXPECT_LE(bytes_written, kTestDataSize);
  };
  auto check_read = [&](int bytes_read, int bytes_written) {
    EXPECT_GT(bytes_read, 0);
    EXPECT_LE(bytes_read, bytes_written);
    EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), bytes_read));
  };

  // Write then read: the data is already queued when Read() is called, so the
  // read completes synchronously.
  const int first_written =
      server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  check_write(first_written);
  check_read(
      client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback()),
      first_written);

  // Read then write: the read is left pending and completes once the data the
  // client writes is delivered.
  TestCompletionCallback read_callback;
  EXPECT_EQ(ERR_IO_PENDING,
            server.Read(read_buf.get(), kReadBufSize, read_callback.callback()));
  const int second_written =
      client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  check_write(second_written);
  check_read(read_callback.WaitForResult(), second_written);
}
347
348 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
349 public:
SSLServerSocketTest()350 SSLServerSocketTest()
351 : ssl_config_service_(
352 std::make_unique<TestSSLConfigService>(SSLContextConfig())),
353 cert_verifier_(std::make_unique<MockCertVerifier>()),
354 client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
355 transport_security_state_(std::make_unique<TransportSecurityState>()),
356 ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
357 SSLClientSessionCache::Config())) {}
358
SetUp()359 void SetUp() override {
360 PlatformTest::SetUp();
361
362 cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
363 client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
364
365 server_cert_ =
366 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
367 ASSERT_TRUE(server_cert_);
368 server_private_key_ = ReadTestKey("unittest.key.bin");
369 ASSERT_TRUE(server_private_key_);
370
371 std::unique_ptr<crypto::RSAPrivateKey> key =
372 ReadTestKey("unittest.key.bin");
373 ASSERT_TRUE(key);
374 server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
375
376 // Certificate provided by the host doesn't need authority.
377 client_ssl_config_.allowed_bad_certs.emplace_back(
378 server_cert_, CERT_STATUS_AUTHORITY_INVALID);
379
380 client_context_ = std::make_unique<SSLClientContext>(
381 ssl_config_service_.get(), cert_verifier_.get(),
382 transport_security_state_.get(), ssl_client_session_cache_.get(),
383 nullptr);
384 }
385
386 protected:
CreateContext()387 void CreateContext() {
388 client_socket_.reset();
389 server_socket_.reset();
390 channel_1_.reset();
391 channel_2_.reset();
392 server_context_ = CreateSSLServerContext(
393 server_cert_.get(), *server_private_key_, server_ssl_config_);
394 }
395
CreateContextSSLPrivateKey()396 void CreateContextSSLPrivateKey() {
397 client_socket_.reset();
398 server_socket_.reset();
399 channel_1_.reset();
400 channel_2_.reset();
401 server_context_.reset();
402 server_context_ = CreateSSLServerContext(
403 server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
404 }
405
GetHostAndPort()406 static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
407
CreateSockets()408 void CreateSockets() {
409 client_socket_.reset();
410 server_socket_.reset();
411 channel_1_ = std::make_unique<FakeDataChannel>();
412 channel_2_ = std::make_unique<FakeDataChannel>();
413 std::unique_ptr<StreamSocket> client_connection =
414 std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
415 std::unique_ptr<StreamSocket> server_socket =
416 std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
417
418 client_socket_ = client_context_->CreateSSLClientSocket(
419 std::move(client_connection), GetHostAndPort(), client_ssl_config_);
420 ASSERT_TRUE(client_socket_);
421
422 server_socket_ =
423 server_context_->CreateSSLServerSocket(std::move(server_socket));
424 ASSERT_TRUE(server_socket_);
425 }
426
427 // Client certificates are disabled on iOS.
428 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)429 void ConfigureClientCertsForClient(const char* cert_file_name,
430 const char* private_key_file_name) {
431 scoped_refptr<X509Certificate> client_cert =
432 ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
433 ASSERT_TRUE(client_cert);
434
435 std::unique_ptr<crypto::RSAPrivateKey> key =
436 ReadTestKey(private_key_file_name);
437 ASSERT_TRUE(key);
438
439 client_context_->SetClientCertificate(
440 GetHostAndPort(), std::move(client_cert),
441 WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
442 }
443
ConfigureClientCertsForServer()444 void ConfigureClientCertsForServer() {
445 server_ssl_config_.client_cert_type =
446 SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
447
448 // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
449 static const uint8_t kClientCertCAName[] = {
450 0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
451 0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
452 server_ssl_config_.cert_authorities.emplace_back(
453 std::begin(kClientCertCAName), std::end(kClientCertCAName));
454
455 scoped_refptr<X509Certificate> expected_client_cert(
456 ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
457 ASSERT_TRUE(expected_client_cert);
458
459 client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
460
461 server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
462 }
463 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
464
ReadTestKey(std::string_view name)465 std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(std::string_view name) {
466 base::FilePath certs_dir(GetTestCertsDirectory());
467 base::FilePath key_path = certs_dir.AppendASCII(name);
468 std::string key_string;
469 if (!base::ReadFileToString(key_path, &key_string))
470 return nullptr;
471 std::vector<uint8_t> key_vector(
472 reinterpret_cast<const uint8_t*>(key_string.data()),
473 reinterpret_cast<const uint8_t*>(key_string.data() +
474 key_string.length()));
475 std::unique_ptr<crypto::RSAPrivateKey> key(
476 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
477 return key;
478 }
479
PumpServerToClient()480 void PumpServerToClient() {
481 const int kReadBufSize = 1024;
482 scoped_refptr<StringIOBuffer> write_buf =
483 base::MakeRefCounted<StringIOBuffer>("testing123");
484 scoped_refptr<DrainableIOBuffer> read_buf =
485 base::MakeRefCounted<DrainableIOBuffer>(
486 base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
487 TestCompletionCallback write_callback;
488 TestCompletionCallback read_callback;
489 int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
490 write_callback.callback(),
491 TRAFFIC_ANNOTATION_FOR_TESTS);
492 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
493 int client_ret = client_socket_->Read(
494 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
495 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
496
497 server_ret = write_callback.GetResult(server_ret);
498 EXPECT_GT(server_ret, 0);
499 client_ret = read_callback.GetResult(client_ret);
500 ASSERT_GT(client_ret, 0);
501 }
502
503 std::unique_ptr<FakeDataChannel> channel_1_;
504 std::unique_ptr<FakeDataChannel> channel_2_;
505 std::unique_ptr<TestSSLConfigService> ssl_config_service_;
506 std::unique_ptr<MockCertVerifier> cert_verifier_;
507 std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
508 SSLConfig client_ssl_config_;
509 // Note that this has a pointer to the `cert_verifier_`, so must be destroyed
510 // before that is.
511 SSLServerConfig server_ssl_config_;
512 std::unique_ptr<TransportSecurityState> transport_security_state_;
513 std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
514 std::unique_ptr<SSLClientContext> client_context_;
515 std::unique_ptr<SSLServerContext> server_context_;
516 std::unique_ptr<SSLClientSocket> client_socket_;
517 std::unique_ptr<SSLServerSocket> server_socket_;
518 std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
519 scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
520 scoped_refptr<X509Certificate> server_cert_;
521 };
522
523 class SSLServerSocketReadTest : public SSLServerSocketTest,
524 public ::testing::WithParamInterface<bool> {
525 protected:
SSLServerSocketReadTest()526 SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
527
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)528 int Read(StreamSocket* socket,
529 IOBuffer* buf,
530 int buf_len,
531 CompletionOnceCallback callback) {
532 if (read_if_ready_enabled()) {
533 return socket->ReadIfReady(buf, buf_len, std::move(callback));
534 }
535 return socket->Read(buf, buf_len, std::move(callback));
536 }
537
read_if_ready_enabled() const538 bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
539
540 private:
541 const bool read_if_ready_enabled_;
542 };
543
544 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
545 SSLServerSocketReadTest,
546 ::testing::Bool());
547
548 // This test only executes creation of client and server sockets. This is to
549 // test that creation of sockets doesn't crash and have minimal code to run
550 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)551 TEST_F(SSLServerSocketTest, Initialize) {
552 ASSERT_NO_FATAL_FAILURE(CreateContext());
553 ASSERT_NO_FATAL_FAILURE(CreateSockets());
554 }
555
556 // This test executes Connect() on SSLClientSocket and Handshake() on
557 // SSLServerSocket to make sure handshaking between the two sockets is
558 // completed successfully.
TEST_F(SSLServerSocketTest, Handshake) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Start both ends; either may complete synchronously or asynchronously, and
  // GetResult() handles both cases.
  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  client_ret = connect_callback.GetResult(client_ret);
  server_ret = handshake_callback.GetResult(server_ret);

  ASSERT_THAT(client_ret, IsOk());
  ASSERT_THAT(server_ret, IsOk());

  // The self-signed server certificate is reported as AUTHORITY_INVALID; it
  // was allowed through |allowed_bad_certs| in SetUp().
  SSLInfo ssl_info;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
  EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);

  // The negotiated cipher suite should be an AEAD. Only |is_aead| is
  // asserted; the other outputs of SSLCipherSuiteToStrings() are not checked.
  uint16_t cipher_suite =
      SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
  const char* key_exchange = nullptr;
  const char* cipher = nullptr;
  const char* mac = nullptr;
  bool is_aead = false;
  bool is_tls13 = false;
  SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
                          cipher_suite);
  EXPECT_TRUE(is_aead);
}
592
593 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest, HandshakeCached) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());

  // Creates a fresh socket pair, handshakes both ends against the current
  // server context, and checks that both sides report |expected_type|.
  auto run_handshake = [this](auto expected_type) {
    ASSERT_NO_FATAL_FAILURE(CreateSockets());

    TestCompletionCallback server_done;
    int server_rv = server_socket_->Handshake(server_done.callback());
    TestCompletionCallback client_done;
    int client_rv = client_socket_->Connect(client_done.callback());

    client_rv = client_done.GetResult(client_rv);
    server_rv = server_done.GetResult(server_rv);
    ASSERT_THAT(client_rv, IsOk());
    ASSERT_THAT(server_rv, IsOk());

    SSLInfo client_info;
    ASSERT_TRUE(client_socket_->GetSSLInfo(&client_info));
    EXPECT_EQ(client_info.handshake_type, expected_type);
    SSLInfo server_info;
    ASSERT_TRUE(server_socket_->GetSSLInfo(&server_info));
    EXPECT_EQ(server_info.handshake_type, expected_type);
  };

  // Nothing is cached yet, so the first connection is a full handshake.
  ASSERT_NO_FATAL_FAILURE(run_handshake(SSLInfo::HANDSHAKE_FULL));

  // Pump client read to get new session tickets.
  PumpServerToClient();

  // The second connection should resume the cached session.
  ASSERT_NO_FATAL_FAILURE(run_handshake(SSLInfo::HANDSHAKE_RESUME));
}
643
644 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  client_ret = connect_callback.GetResult(client_ret);
  server_ret = handshake_callback.GetResult(server_ret);

  ASSERT_THAT(client_ret, IsOk());
  ASSERT_THAT(server_ret, IsOk());

  // The first connection is a full handshake on both sides.
  SSLInfo ssl_info;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
  EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
  SSLInfo ssl_server_info;
  ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
  EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);

  // Pump client read to get new session tickets. Without this, the client may
  // have no session cached (session tickets can arrive after the handshake
  // completes). The full handshake expected below would then happen whatever
  // server context is used, and the test would prove nothing.
  PumpServerToClient();

  // Make sure the second connection is NOT cached when using a new context.
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback2;
  int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());

  TestCompletionCallback connect_callback2;
  int client_ret2 = client_socket_->Connect(connect_callback2.callback());

  client_ret2 = connect_callback2.GetResult(client_ret2);
  server_ret2 = handshake_callback2.GetResult(server_ret2);

  ASSERT_THAT(client_ret2, IsOk());
  ASSERT_THAT(server_ret2, IsOk());

  // The new server context cannot resume the session established with the
  // old one, so this is a full handshake as well.
  SSLInfo ssl_info2;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
  EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
  SSLInfo ssl_server_info2;
  ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
  EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
}
693
694 // Client certificates are disabled on iOS.
695 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
696 // This test executes Connect() on SSLClientSocket and Handshake() on
697 // SSLServerSocket to make sure handshaking between the two sockets is
698 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
  // Expected certificate for the server to observe; fail early instead of
  // dereferencing null below if it cannot be loaded.
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kClientCertFileName, kClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  client_ret = connect_callback.GetResult(client_ret);
  server_ret = handshake_callback.GetResult(server_ret);

  ASSERT_THAT(client_ret, IsOk());
  ASSERT_THAT(server_ret, IsOk());

  // The client sees the self-signed server certificate as AUTHORITY_INVALID.
  SSLInfo client_ssl_info;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&client_ssl_info));
  EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, client_ssl_info.cert_status);

  // The server sees the certificate the client presented.
  SSLInfo server_ssl_info;
  ASSERT_TRUE(server_socket_->GetSSLInfo(&server_ssl_info));
  ASSERT_TRUE(server_ssl_info.cert.get());
  EXPECT_TRUE(client_cert->EqualsExcludingChain(server_ssl_info.cert.get()));
}
728
729 // This test executes Connect() on SSLClientSocket and Handshake() twice on
730 // SSLServerSocket to make sure handshaking between the two sockets is
731 // completed successfully, using client certificate. The second connection is
732 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
  // Expected certificate for the server to observe; fail early instead of
  // dereferencing null below if it cannot be loaded.
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kClientCertFileName, kClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  client_ret = connect_callback.GetResult(client_ret);
  server_ret = handshake_callback.GetResult(server_ret);

  ASSERT_THAT(client_ret, IsOk());
  ASSERT_THAT(server_ret, IsOk());

  // First connection: full handshake, and the server sees the client cert.
  SSLInfo ssl_info;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
  EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
  SSLInfo ssl_server_info;
  ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
  ASSERT_TRUE(ssl_server_info.cert.get());
  EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
  EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
  // Pump client read to get new session tickets.
  PumpServerToClient();
  server_socket_->Disconnect();
  client_socket_->Disconnect();

  // Create the connection again.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback handshake_callback2;
  int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());

  TestCompletionCallback connect_callback2;
  int client_ret2 = client_socket_->Connect(connect_callback2.callback());

  client_ret2 = connect_callback2.GetResult(client_ret2);
  server_ret2 = handshake_callback2.GetResult(server_ret2);

  ASSERT_THAT(client_ret2, IsOk());
  ASSERT_THAT(server_ret2, IsOk());

  // Second connection: resumed, and the resumed session still carries the
  // client certificate on the server side.
  SSLInfo ssl_info2;
  ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
  EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
  SSLInfo ssl_server_info2;
  ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
  ASSERT_TRUE(ssl_server_info2.cert.get());
  EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
  EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
}
792
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // The client keeps its default configuration and has no client certificate
  // to send. Its Connect() stops with ERR_SSL_CLIENT_AUTH_CERT_NEEDED, which
  // lets the test inspect the certificate authorities listed in the server's
  // CertificateRequest.
  TestCompletionCallback server_handshake_done;
  const int server_rv =
      server_socket_->Handshake(server_handshake_done.callback());

  TestCompletionCallback client_connect_done;
  const int client_rv =
      client_socket_->Connect(client_connect_done.callback());
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
            client_connect_done.GetResult(client_rv));

  auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // The authority name carried by the CertificateRequest must match the
  // issuer of the expected client certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);
  EXPECT_TRUE(expected_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  // Once the client disconnects, the server's pending handshake fails with
  // ERR_CONNECTION_CLOSED.
  client_socket_->Disconnect();
  EXPECT_THAT(server_handshake_done.GetResult(server_rv),
              IsError(ERR_CONNECTION_CLOSED));
}
826
// Same scenario as HandshakeWithClientCertRequiredNotSupplied, run twice: the
// second attempt uses a fresh socket pair without recreating the context, to
// check that the cache didn't store the result of the failed handshake.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // The CertificateRequest must name the issuer of this certificate.
  scoped_refptr<X509Certificate> expected_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(expected_cert);

  // Runs one attempt on the current socket pair. The client uses its default
  // setting (no client certificate), so:
  //  - Connect() fails with ERR_SSL_CLIENT_AUTH_CERT_NEEDED;
  //  - the received cert_authorities cover |expected_cert|'s issuer;
  //  - after the client disconnects, the server handshake reports
  //    ERR_CONNECTION_CLOSED.
  // Only non-fatal EXPECTs are used here, since a fatal ASSERT inside a lambda
  // would return from the lambda rather than from the test.
  auto expect_cert_needed = [&]() {
    TestCompletionCallback server_cb;
    const int server_rv = server_socket_->Handshake(server_cb.callback());

    TestCompletionCallback client_cb;
    const int client_rv = client_socket_->Connect(client_cb.callback());
    EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_cb.GetResult(client_rv));

    auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
    client_socket_->GetSSLCertRequestInfo(cert_request.get());
    EXPECT_TRUE(
        expected_cert->IsIssuedByEncoded(cert_request->cert_authorities));

    client_socket_->Disconnect();
    EXPECT_THAT(server_cb.GetResult(server_rv),
                IsError(ERR_CONNECTION_CLOSED));
  };

  expect_cert_needed();
  server_socket_->Disconnect();

  // A second attempt on new sockets must behave exactly like the first one.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  expect_cert_needed();
}
883
// The server requires a client certificate and the client presents one the
// server rejects. The server handshake fails with ERR_BAD_SSL_CLIENT_AUTH_CERT.
// With TLS 1.3 (this test does not cap the version) the client's Connect()
// still succeeds, and the rejection only surfaces on the first Read().
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump client read to get client cert error.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
}
917
// Same as HandshakeWithWrongClientCertSupplied, but with the client capped at
// TLS 1.2. Here the rejected client certificate is reported to the client
// during the handshake itself, so both Connect() and Handshake() fail with
// ERR_BAD_SSL_CLIENT_AUTH_CERT.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));
}
941
// Same scenario as HandshakeWithWrongClientCertSupplied, run twice: the second
// attempt uses a fresh socket pair without recreating the context, to check
// that the cache didn't store the result of the failed handshake.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump client read to get client cert error.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);

  client_socket_->Disconnect();
  server_socket_->Disconnect();

  // Below, check that the cache didn't store the result of a failed handshake.
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  TestCompletionCallback handshake_callback2;
  int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());

  TestCompletionCallback connect_callback2;
  int client_ret2 = client_socket_->Connect(connect_callback2.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback2.GetResult(server_ret2));

  // Pump client read to get client cert error. Use state belonging to this
  // second attempt rather than the first attempt's callback and result.
  // The first Read() failed without consuming anything, so |read_buf| still
  // has its full capacity available.
  TestCompletionCallback read_callback2;
  client_ret2 = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback2.callback());
  client_ret2 = read_callback2.GetResult(client_ret2);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret2);
}
997 #endif // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
998
// Establishes a connection and exchanges data in both directions, checking
// that the bytes received match the bytes sent:
//   1. The server writes and the client reads.
//   2. The server reads and the client writes. The server side goes through
//      the fixture's Read() helper, which (depending on
//      read_if_ready_enabled()) exercises either Read() or ReadIfReady().
TEST_P(SSLServerSocketReadTest, DataTransfer) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Establish connection.
  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);

  client_ret = connect_callback.GetResult(client_ret);
  ASSERT_THAT(client_ret, IsOk());
  server_ret = handshake_callback.GetResult(server_ret);
  ASSERT_THAT(server_ret, IsOk());

  // One drainable buffer is used for both directions; it is rewound with
  // SetOffset(0) after each comparison.
  const int kReadBufSize = 1024;
  scoped_refptr<StringIOBuffer> write_buf =
      base::MakeRefCounted<StringIOBuffer>("testing123");
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);

  // Write then read.
  TestCompletionCallback write_callback;
  TestCompletionCallback read_callback;
  server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
  client_ret = client_socket_->Read(
      read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = write_callback.GetResult(server_ret);
  EXPECT_GT(server_ret, 0);
  client_ret = read_callback.GetResult(client_ret);
  ASSERT_GT(client_ret, 0);

  // A single Read() may return only part of the payload; keep reading until
  // everything the server wrote has been consumed.
  read_buf->DidConsume(client_ret);
  while (read_buf->BytesConsumed() < write_buf->size()) {
    client_ret = client_socket_->Read(
        read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
    EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
    client_ret = read_callback.GetResult(client_ret);
    ASSERT_GT(client_ret, 0);
    read_buf->DidConsume(client_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  // Rewind so data() points at the first received byte for the comparison.
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));

  // Read then write.
  write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
  // The client has not written anything yet, so the server read must pend.
  server_ret = Read(server_socket_.get(), read_buf.get(),
                    read_buf->BytesRemaining(), read_callback.callback());
  EXPECT_EQ(server_ret, ERR_IO_PENDING);
  client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  server_ret = read_callback.GetResult(server_ret);
  if (read_if_ready_enabled()) {
    // ReadIfReady signals the data is available but does not consume it.
    // The data is consumed later below.
    ASSERT_EQ(server_ret, OK);
  } else {
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  client_ret = write_callback.GetResult(client_ret);
  EXPECT_GT(client_ret, 0);

  while (read_buf->BytesConsumed() < write_buf->size()) {
    server_ret = Read(server_socket_.get(), read_buf.get(),
                      read_buf->BytesRemaining(), read_callback.callback());
    // All the data was written above, so the data should be synchronously
    // available out of both Read() and ReadIfReady().
    ASSERT_GT(server_ret, 0);
    read_buf->DidConsume(server_ret);
  }
  EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
  read_buf->SetOffset(0);
  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
}
1087
// A regression test for bug 127822 (http://crbug.com/127822).
// If the server closes the connection after the handshake is finished,
// the client's Write() call should not cause an infinite loop.
// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  // Establish connection.
  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);

  client_ret = connect_callback.GetResult(client_ret);
  ASSERT_THAT(client_ret, IsOk());
  server_ret = handshake_callback.GetResult(server_ret);
  ASSERT_THAT(server_ret, IsOk());

  // The same payload is written first by the server and later by the client.
  scoped_refptr<StringIOBuffer> write_buf =
      base::MakeRefCounted<StringIOBuffer>("testing123");

  // The server closes the connection. The server needs to write some
  // data first so that the client's Read() calls from the transport
  // socket won't return ERR_IO_PENDING. This ensures that the client
  // will call Read() on the transport socket again.
  TestCompletionCallback write_callback;
  server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);

  server_ret = write_callback.GetResult(server_ret);
  EXPECT_GT(server_ret, 0);

  server_socket_->Disconnect();

  // The client writes some data. This should not cause an infinite loop.
  // The write is still expected to report a positive byte count even though
  // the server side has already been disconnected.
  client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
                                     write_callback.callback(),
                                     TRAFFIC_ANNOTATION_FOR_TESTS);
  EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);

  client_ret = write_callback.GetResult(client_ret);
  EXPECT_GT(client_ret, 0);

  // Spin the message loop for a short fixed delay before the test ends.
  // NOTE(review): presumably this gives any tasks the client socket posted
  // after the write a chance to run (and to hang, if the bug regresses).
  base::RunLoop run_loop;
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
  run_loop.Run();
}
1142
1143 // This test executes ExportKeyingMaterial() on the client and server sockets,
1144 // after connecting them, and verifies that the results match.
1145 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1146 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1147 ASSERT_NO_FATAL_FAILURE(CreateContext());
1148 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1149
1150 TestCompletionCallback connect_callback;
1151 int client_ret = client_socket_->Connect(connect_callback.callback());
1152 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1153
1154 TestCompletionCallback handshake_callback;
1155 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1156 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1157
1158 if (client_ret == ERR_IO_PENDING) {
1159 ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1160 }
1161 if (server_ret == ERR_IO_PENDING) {
1162 ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1163 }
1164
1165 const int kKeyingMaterialSize = 32;
1166 const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1167 const char kKeyingContext[] = "";
1168 unsigned char server_out[kKeyingMaterialSize];
1169 int rv = server_socket_->ExportKeyingMaterial(
1170 kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out));
1171 ASSERT_THAT(rv, IsOk());
1172
1173 unsigned char client_out[kKeyingMaterialSize];
1174 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext,
1175 client_out, sizeof(client_out));
1176 ASSERT_THAT(rv, IsOk());
1177 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
1178
1179 const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1180 unsigned char client_bad[kKeyingMaterialSize];
1181 rv = client_socket_->ExportKeyingMaterial(
1182 kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad));
1183 ASSERT_EQ(rv, OK);
1184 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
1185 }
1186
1187 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1188 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1189 // Disable all ECDHE suites on the client side.
1190 SSLContextConfig config;
1191 config.disabled_cipher_suites.assign(
1192 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1193
1194 // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1195 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1196 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1197
1198 // Require ECDHE on the server.
1199 server_ssl_config_.require_ecdhe = true;
1200
1201 ASSERT_NO_FATAL_FAILURE(CreateContext());
1202 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1203
1204 TestCompletionCallback connect_callback;
1205 int client_ret = client_socket_->Connect(connect_callback.callback());
1206
1207 TestCompletionCallback handshake_callback;
1208 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1209
1210 client_ret = connect_callback.GetResult(client_ret);
1211 server_ret = handshake_callback.GetResult(server_ret);
1212
1213 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1214 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1215 }
1216
1217 // This test executes Connect() on SSLClientSocket and Handshake() on
1218 // SSLServerSocket to make sure handshaking between the two sockets is
1219 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1220 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1221 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1222 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1223
1224 TestCompletionCallback handshake_callback;
1225 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1226
1227 TestCompletionCallback connect_callback;
1228 int client_ret = client_socket_->Connect(connect_callback.callback());
1229
1230 client_ret = connect_callback.GetResult(client_ret);
1231 server_ret = handshake_callback.GetResult(server_ret);
1232
1233 ASSERT_THAT(client_ret, IsOk());
1234 ASSERT_THAT(server_ret, IsOk());
1235
1236 // Make sure the cert status is expected.
1237 SSLInfo ssl_info;
1238 ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1239 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1240
1241 // The default cipher suite should be ECDHE and an AEAD.
1242 uint16_t cipher_suite =
1243 SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1244 const char* key_exchange;
1245 const char* cipher;
1246 const char* mac;
1247 bool is_aead;
1248 bool is_tls13;
1249 SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1250 cipher_suite);
1251 EXPECT_TRUE(is_aead);
1252 }
1253
1254 namespace {
1255
1256 // Helper that wraps an underlying SSLPrivateKey to allow the test to
1257 // do some work immediately before a `Sign()` operation is performed.
1258 class SSLPrivateKeyHook : public SSLPrivateKey {
1259 public:
SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,base::RepeatingClosure on_sign)1260 SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
1261 base::RepeatingClosure on_sign)
1262 : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
1263
1264 // SSLPrivateKey implementation.
GetProviderName()1265 std::string GetProviderName() override {
1266 return private_key_->GetProviderName();
1267 }
GetAlgorithmPreferences()1268 std::vector<uint16_t> GetAlgorithmPreferences() override {
1269 return private_key_->GetAlgorithmPreferences();
1270 }
Sign(uint16_t algorithm,base::span<const uint8_t> input,SignCallback callback)1271 void Sign(uint16_t algorithm,
1272 base::span<const uint8_t> input,
1273 SignCallback callback) override {
1274 on_sign_.Run();
1275 private_key_->Sign(algorithm, input, std::move(callback));
1276 }
1277
1278 private:
1279 ~SSLPrivateKeyHook() override = default;
1280
1281 const scoped_refptr<SSLPrivateKey> private_key_;
1282 const base::RepeatingClosure on_sign_;
1283 };
1284
1285 } // namespace
1286
1287 // Verifies that if the client disconnects while during private key signing then
1288 // the disconnection is correctly reported to the `Handshake()` completion
1289 // callback, with `ERR_CONNECTION_CLOSED`.
1290 // This is a regression test for crbug.com/1449461.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError)1291 TEST_F(SSLServerSocketTest,
1292 HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
1293 auto on_sign = base::BindLambdaForTesting([&]() {
1294 client_socket_->Disconnect();
1295 ASSERT_FALSE(client_socket_->IsConnected());
1296 });
1297 server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>(
1298 std::move(server_ssl_private_key_), on_sign);
1299 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1300 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1301
1302 TestCompletionCallback handshake_callback;
1303 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1304 ASSERT_EQ(server_ret, net::ERR_IO_PENDING);
1305
1306 TestCompletionCallback connect_callback;
1307 client_socket_->Connect(connect_callback.callback());
1308
1309 // If resuming the handshake after private-key signing is not handled
1310 // correctly as per crbug.com/1449461 then the test will hang and timeout
1311 // at this point, due to the server-side completion callback not being
1312 // correctly invoked.
1313 server_ret = handshake_callback.GetResult(server_ret);
1314 EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED);
1315 }
1316
1317 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1318 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1319 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1320 // Disable all ECDHE suites on the client side.
1321 SSLContextConfig config;
1322 config.disabled_cipher_suites.assign(
1323 kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1324 // TLS 1.3 always works with SSLPrivateKey.
1325 config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1326 ssl_config_service_->UpdateSSLConfigAndNotify(config);
1327
1328 ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1329 ASSERT_NO_FATAL_FAILURE(CreateSockets());
1330
1331 TestCompletionCallback connect_callback;
1332 int client_ret = client_socket_->Connect(connect_callback.callback());
1333
1334 TestCompletionCallback handshake_callback;
1335 int server_ret = server_socket_->Handshake(handshake_callback.callback());
1336
1337 client_ret = connect_callback.GetResult(client_ret);
1338 server_ret = handshake_callback.GetResult(server_ret);
1339
1340 ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1341 ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1342 }
1343
1344 class SSLServerSocketAlpsTest
1345 : public SSLServerSocketTest,
1346 public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1347 public:
SSLServerSocketAlpsTest()1348 SSLServerSocketAlpsTest()
1349 : client_alps_enabled_(std::get<0>(GetParam())),
1350 server_alps_enabled_(std::get<1>(GetParam())) {}
1351 ~SSLServerSocketAlpsTest() override = default;
1352 const bool client_alps_enabled_;
1353 const bool server_alps_enabled_;
1354 };
1355
1356 INSTANTIATE_TEST_SUITE_P(All,
1357 SSLServerSocketAlpsTest,
1358 ::testing::Combine(::testing::Bool(),
1359 ::testing::Bool()));
1360
// Negotiates HTTP/2 via ALPN with ALPS optionally configured on each side, and
// checks that application settings reach the peer only when both the client
// and the server have ALPS enabled.
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  // Both ends agree on HTTP/2.
  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  client_ssl_config_.alpn_protos = {kProtoHTTP2};

  // Each side attaches its ALPS payload only when enabled by the parameter.
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_data.begin(), server_data.end());
  }
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_data.begin(), client_data.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_cb;
  int server_rv = server_socket_->Handshake(handshake_cb.callback());

  TestCompletionCallback connect_cb;
  int client_rv = client_socket_->Connect(connect_cb.callback());

  client_rv = connect_cb.GetResult(client_rv);
  server_rv = handshake_cb.GetResult(server_rv);

  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  // What each side received from its peer.
  const auto from_server = client_socket_->GetPeerApplicationSettings();
  const auto from_client = server_socket_->GetPeerApplicationSettings();

  // ALPS is negotiated only if ALPS is enabled both on client and server.
  if (!client_alps_enabled_ || !server_alps_enabled_) {
    EXPECT_FALSE(from_server.has_value());
    EXPECT_FALSE(from_client.has_value());
    return;
  }

  ASSERT_TRUE(from_server.has_value());
  EXPECT_EQ(server_data, from_server.value());
  ASSERT_TRUE(from_client.has_value());
  EXPECT_EQ(client_data, from_client.value());
}
1408
// Test that CancelReadIfReady works: once a pending ReadIfReady() is canceled,
// its callback must not run even after data arrives, and later reads on the
// same socket must still succeed.
TEST_F(SSLServerSocketTest, CancelReadIfReady) {
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());
  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());
  ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
  ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());

  // Attempt to read from the server socket. There will not be anything to read.
  // Cancel the read immediately afterwards.
  TestCompletionCallback read_callback;
  auto read_buf = base::MakeRefCounted<IOBufferWithSize>(1);
  int read_ret =
      server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
  ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
  ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());

  // After the client writes data, the server should still not pick up a result.
  auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
  TestCompletionCallback write_callback;
  ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
                write_buf.get(), write_buf->size(), write_callback.callback(),
                TRAFFIC_ANNOTATION_FOR_TESTS)),
            write_buf->size());
  // Let any posted tasks run; the canceled callback must still not fire.
  base::RunLoop().RunUntilIdle();
  EXPECT_FALSE(read_callback.have_result());

  // After a canceled read, future reads are still possible.
  // ReadIfReady() only signals readiness (its callback completes with OK and
  // consumes nothing), so retry with a fresh callback until the call returns
  // data synchronously.
  while (true) {
    TestCompletionCallback read_callback2;
    read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
                                           read_callback2.callback());
    if (read_ret != ERR_IO_PENDING) {
      break;
    }
    ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
  }
  // Exactly the one byte the client wrote should have been read.
  ASSERT_EQ(1, read_ret);
  EXPECT_EQ(read_buf->data()[0], 'a');
}
1453
1454 } // namespace net
1455