xref: /aosp_15_r20/external/cronet/net/socket/ssl_server_socket_unittest.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 // This test suite uses SSLClientSocket to test the implementation of
6 // SSLServerSocket. In order to establish connections between the sockets
7 // we need two additional classes:
8 // 1. FakeSocket
9 //    Connects SSL socket to FakeDataChannel. This class is just a stub.
10 //
11 // 2. FakeDataChannel
12 //    Implements the actual exchange of data between two FakeSockets.
13 //
14 // Implementations of these two classes are included in this file.
15 
16 #include "net/socket/ssl_server_socket.h"
17 
18 #include <stdint.h>
19 #include <stdlib.h>
20 
21 #include <memory>
22 #include <string_view>
23 #include <utility>
24 
25 #include "base/check.h"
26 #include "base/compiler_specific.h"
27 #include "base/containers/queue.h"
28 #include "base/files/file_path.h"
29 #include "base/files/file_util.h"
30 #include "base/functional/bind.h"
31 #include "base/functional/callback_helpers.h"
32 #include "base/location.h"
33 #include "base/memory/raw_ptr.h"
34 #include "base/memory/scoped_refptr.h"
35 #include "base/notreached.h"
36 #include "base/run_loop.h"
37 #include "base/task/single_thread_task_runner.h"
38 #include "base/test/bind.h"
39 #include "base/test/task_environment.h"
40 #include "build/build_config.h"
41 #include "crypto/rsa_private_key.h"
42 #include "net/base/address_list.h"
43 #include "net/base/completion_once_callback.h"
44 #include "net/base/host_port_pair.h"
45 #include "net/base/io_buffer.h"
46 #include "net/base/ip_address.h"
47 #include "net/base/ip_endpoint.h"
48 #include "net/base/net_errors.h"
49 #include "net/cert/cert_status_flags.h"
50 #include "net/cert/mock_cert_verifier.h"
51 #include "net/cert/mock_client_cert_verifier.h"
52 #include "net/cert/signed_certificate_timestamp_and_status.h"
53 #include "net/cert/x509_certificate.h"
54 #include "net/http/transport_security_state.h"
55 #include "net/log/net_log_with_source.h"
56 #include "net/socket/client_socket_factory.h"
57 #include "net/socket/socket_test_util.h"
58 #include "net/socket/ssl_client_socket.h"
59 #include "net/socket/stream_socket.h"
60 #include "net/ssl/openssl_private_key.h"
61 #include "net/ssl/ssl_cert_request_info.h"
62 #include "net/ssl/ssl_cipher_suite_names.h"
63 #include "net/ssl/ssl_client_session_cache.h"
64 #include "net/ssl/ssl_connection_status_flags.h"
65 #include "net/ssl/ssl_info.h"
66 #include "net/ssl/ssl_private_key.h"
67 #include "net/ssl/ssl_server_config.h"
68 #include "net/ssl/test_ssl_config_service.h"
69 #include "net/test/cert_test_util.h"
70 #include "net/test/gtest_util.h"
71 #include "net/test/test_data_directory.h"
72 #include "net/test/test_with_task_environment.h"
73 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
74 #include "testing/gmock/include/gmock/gmock.h"
75 #include "testing/gtest/include/gtest/gtest.h"
76 #include "testing/platform_test.h"
77 #include "third_party/boringssl/src/include/openssl/evp.h"
78 #include "third_party/boringssl/src/include/openssl/ssl.h"
79 
80 using net::test::IsError;
81 using net::test::IsOk;
82 
83 namespace net {
84 
85 namespace {
86 
87 // Client certificates are disabled on iOS.
88 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
89 const char kClientCertFileName[] = "client_1.pem";
90 const char kClientPrivateKeyFileName[] = "client_1.pk8";
91 const char kWrongClientCertFileName[] = "client_2.pem";
92 const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
93 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
94 
// ECDHE-based TLS cipher suites, identified by their two-byte IANA code
// points. Covers both ECDSA- and RSA-authenticated variants.
constexpr uint16_t kEcdheCiphers[] = {
    0xc007,  // ECDHE_ECDSA_WITH_RC4_128_SHA
    0xc009,  // ECDHE_ECDSA_WITH_AES_128_CBC_SHA
    0xc00a,  // ECDHE_ECDSA_WITH_AES_256_CBC_SHA
    0xc011,  // ECDHE_RSA_WITH_RC4_128_SHA
    0xc013,  // ECDHE_RSA_WITH_AES_128_CBC_SHA
    0xc014,  // ECDHE_RSA_WITH_AES_256_CBC_SHA
    0xc02b,  // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
    0xc02c,  // ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
    0xc02f,  // ECDHE_RSA_WITH_AES_128_GCM_SHA256
    0xc030,  // ECDHE_RSA_WITH_AES_256_GCM_SHA384
    0xcca8,  // ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
    0xcca9,  // ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
};
109 
110 class FakeDataChannel {
111  public:
112   FakeDataChannel() = default;
113 
114   FakeDataChannel(const FakeDataChannel&) = delete;
115   FakeDataChannel& operator=(const FakeDataChannel&) = delete;
116 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)117   int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback) {
118     DCHECK(read_callback_.is_null());
119     DCHECK(!read_buf_.get());
120     if (closed_)
121       return 0;
122     if (data_.empty()) {
123       read_callback_ = std::move(callback);
124       read_buf_ = buf;
125       read_buf_len_ = buf_len;
126       return ERR_IO_PENDING;
127     }
128     return PropagateData(buf, buf_len);
129   }
130 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)131   int Write(IOBuffer* buf,
132             int buf_len,
133             CompletionOnceCallback callback,
134             const NetworkTrafficAnnotationTag& traffic_annotation) {
135     DCHECK(write_callback_.is_null());
136     if (closed_) {
137       if (write_called_after_close_)
138         return ERR_CONNECTION_RESET;
139       write_called_after_close_ = true;
140       write_callback_ = std::move(callback);
141       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
142           FROM_HERE, base::BindOnce(&FakeDataChannel::DoWriteCallback,
143                                     weak_factory_.GetWeakPtr()));
144       return ERR_IO_PENDING;
145     }
146     // This function returns synchronously, so make a copy of the buffer.
147     data_.push(base::MakeRefCounted<DrainableIOBuffer>(
148         base::MakeRefCounted<StringIOBuffer>(std::string(buf->data(), buf_len)),
149         buf_len));
150     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
151         FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
152                                   weak_factory_.GetWeakPtr()));
153     return buf_len;
154   }
155 
156   // Closes the FakeDataChannel. After Close() is called, Read() returns 0,
157   // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that
158   // after the FakeDataChannel is closed, the first Write() call completes
159   // asynchronously, which is necessary to reproduce bug 127822.
Close()160   void Close() {
161     closed_ = true;
162     if (!read_callback_.is_null()) {
163       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
164           FROM_HERE, base::BindOnce(&FakeDataChannel::DoReadCallback,
165                                     weak_factory_.GetWeakPtr()));
166     }
167   }
168 
169  private:
DoReadCallback()170   void DoReadCallback() {
171     if (read_callback_.is_null())
172       return;
173 
174     if (closed_) {
175       std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
176       return;
177     }
178 
179     if (data_.empty())
180       return;
181 
182     int copied = PropagateData(read_buf_, read_buf_len_);
183     read_buf_ = nullptr;
184     read_buf_len_ = 0;
185     std::move(read_callback_).Run(copied);
186   }
187 
DoWriteCallback()188   void DoWriteCallback() {
189     if (write_callback_.is_null())
190       return;
191 
192     std::move(write_callback_).Run(ERR_CONNECTION_RESET);
193   }
194 
PropagateData(scoped_refptr<IOBuffer> read_buf,int read_buf_len)195   int PropagateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) {
196     scoped_refptr<DrainableIOBuffer> buf = data_.front();
197     int copied = std::min(buf->BytesRemaining(), read_buf_len);
198     memcpy(read_buf->data(), buf->data(), copied);
199     buf->DidConsume(copied);
200 
201     if (!buf->BytesRemaining())
202       data_.pop();
203     return copied;
204   }
205 
206   CompletionOnceCallback read_callback_;
207   scoped_refptr<IOBuffer> read_buf_;
208   int read_buf_len_ = 0;
209 
210   CompletionOnceCallback write_callback_;
211 
212   base::queue<scoped_refptr<DrainableIOBuffer>> data_;
213 
214   // True if Close() has been called.
215   bool closed_ = false;
216 
217   // Controls the completion of Write() after the FakeDataChannel is closed.
218   // After the FakeDataChannel is closed, the first Write() call completes
219   // asynchronously.
220   bool write_called_after_close_ = false;
221 
222   base::WeakPtrFactory<FakeDataChannel> weak_factory_{this};
223 };
224 
225 class FakeSocket : public StreamSocket {
226  public:
FakeSocket(FakeDataChannel * incoming_channel,FakeDataChannel * outgoing_channel)227   FakeSocket(FakeDataChannel* incoming_channel,
228              FakeDataChannel* outgoing_channel)
229       : incoming_(incoming_channel), outgoing_(outgoing_channel) {}
230 
231   FakeSocket(const FakeSocket&) = delete;
232   FakeSocket& operator=(const FakeSocket&) = delete;
233 
234   ~FakeSocket() override = default;
235 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)236   int Read(IOBuffer* buf,
237            int buf_len,
238            CompletionOnceCallback callback) override {
239     // Read random number of bytes.
240     buf_len = rand() % buf_len + 1;
241     return incoming_->Read(buf, buf_len, std::move(callback));
242   }
243 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)244   int Write(IOBuffer* buf,
245             int buf_len,
246             CompletionOnceCallback callback,
247             const NetworkTrafficAnnotationTag& traffic_annotation) override {
248     // Write random number of bytes.
249     buf_len = rand() % buf_len + 1;
250     return outgoing_->Write(buf, buf_len, std::move(callback),
251                             TRAFFIC_ANNOTATION_FOR_TESTS);
252   }
253 
SetReceiveBufferSize(int32_t size)254   int SetReceiveBufferSize(int32_t size) override { return OK; }
255 
SetSendBufferSize(int32_t size)256   int SetSendBufferSize(int32_t size) override { return OK; }
257 
Connect(CompletionOnceCallback callback)258   int Connect(CompletionOnceCallback callback) override { return OK; }
259 
Disconnect()260   void Disconnect() override {
261     incoming_->Close();
262     outgoing_->Close();
263   }
264 
IsConnected() const265   bool IsConnected() const override { return true; }
266 
IsConnectedAndIdle() const267   bool IsConnectedAndIdle() const override { return true; }
268 
GetPeerAddress(IPEndPoint * address) const269   int GetPeerAddress(IPEndPoint* address) const override {
270     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
271     return OK;
272   }
273 
GetLocalAddress(IPEndPoint * address) const274   int GetLocalAddress(IPEndPoint* address) const override {
275     *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/);
276     return OK;
277   }
278 
NetLog() const279   const NetLogWithSource& NetLog() const override { return net_log_; }
280 
WasEverUsed() const281   bool WasEverUsed() const override { return true; }
282 
GetNegotiatedProtocol() const283   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
284 
GetSSLInfo(SSLInfo * ssl_info)285   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
286 
GetTotalReceivedBytes() const287   int64_t GetTotalReceivedBytes() const override {
288     NOTIMPLEMENTED();
289     return 0;
290   }
291 
ApplySocketTag(const SocketTag & tag)292   void ApplySocketTag(const SocketTag& tag) override {}
293 
294  private:
295   NetLogWithSource net_log_;
296   raw_ptr<FakeDataChannel> incoming_;
297   raw_ptr<FakeDataChannel> outgoing_;
298 };
299 
300 }  // namespace
301 
302 // Verify the correctness of the test helper classes first.
TEST(FakeSocketTest,DataTransfer)303 TEST(FakeSocketTest, DataTransfer) {
304   base::test::TaskEnvironment task_environment;
305 
306   // Establish channels between two sockets.
307   FakeDataChannel channel_1;
308   FakeDataChannel channel_2;
309   FakeSocket client(&channel_1, &channel_2);
310   FakeSocket server(&channel_2, &channel_1);
311 
312   const char kTestData[] = "testing123";
313   const int kTestDataSize = strlen(kTestData);
314   const int kReadBufSize = 1024;
315   auto write_buf = base::MakeRefCounted<StringIOBuffer>(kTestData);
316   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(kReadBufSize);
317 
318   // Write then read.
319   int written =
320       server.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
321                    TRAFFIC_ANNOTATION_FOR_TESTS);
322   EXPECT_GT(written, 0);
323   EXPECT_LE(written, kTestDataSize);
324 
325   int read =
326       client.Read(read_buf.get(), kReadBufSize, CompletionOnceCallback());
327   EXPECT_GT(read, 0);
328   EXPECT_LE(read, written);
329   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
330 
331   // Read then write.
332   TestCompletionCallback callback;
333   EXPECT_EQ(ERR_IO_PENDING,
334             server.Read(read_buf.get(), kReadBufSize, callback.callback()));
335 
336   written =
337       client.Write(write_buf.get(), kTestDataSize, CompletionOnceCallback(),
338                    TRAFFIC_ANNOTATION_FOR_TESTS);
339   EXPECT_GT(written, 0);
340   EXPECT_LE(written, kTestDataSize);
341 
342   read = callback.WaitForResult();
343   EXPECT_GT(read, 0);
344   EXPECT_LE(read, written);
345   EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read));
346 }
347 
348 class SSLServerSocketTest : public PlatformTest, public WithTaskEnvironment {
349  public:
SSLServerSocketTest()350   SSLServerSocketTest()
351       : ssl_config_service_(
352             std::make_unique<TestSSLConfigService>(SSLContextConfig())),
353         cert_verifier_(std::make_unique<MockCertVerifier>()),
354         client_cert_verifier_(std::make_unique<MockClientCertVerifier>()),
355         transport_security_state_(std::make_unique<TransportSecurityState>()),
356         ssl_client_session_cache_(std::make_unique<SSLClientSessionCache>(
357             SSLClientSessionCache::Config())) {}
358 
SetUp()359   void SetUp() override {
360     PlatformTest::SetUp();
361 
362     cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
363     client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
364 
365     server_cert_ =
366         ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der");
367     ASSERT_TRUE(server_cert_);
368     server_private_key_ = ReadTestKey("unittest.key.bin");
369     ASSERT_TRUE(server_private_key_);
370 
371     std::unique_ptr<crypto::RSAPrivateKey> key =
372         ReadTestKey("unittest.key.bin");
373     ASSERT_TRUE(key);
374     server_ssl_private_key_ = WrapOpenSSLPrivateKey(bssl::UpRef(key->key()));
375 
376     // Certificate provided by the host doesn't need authority.
377     client_ssl_config_.allowed_bad_certs.emplace_back(
378         server_cert_, CERT_STATUS_AUTHORITY_INVALID);
379 
380     client_context_ = std::make_unique<SSLClientContext>(
381         ssl_config_service_.get(), cert_verifier_.get(),
382         transport_security_state_.get(), ssl_client_session_cache_.get(),
383         nullptr);
384   }
385 
386  protected:
CreateContext()387   void CreateContext() {
388     client_socket_.reset();
389     server_socket_.reset();
390     channel_1_.reset();
391     channel_2_.reset();
392     server_context_ = CreateSSLServerContext(
393         server_cert_.get(), *server_private_key_, server_ssl_config_);
394   }
395 
CreateContextSSLPrivateKey()396   void CreateContextSSLPrivateKey() {
397     client_socket_.reset();
398     server_socket_.reset();
399     channel_1_.reset();
400     channel_2_.reset();
401     server_context_.reset();
402     server_context_ = CreateSSLServerContext(
403         server_cert_.get(), server_ssl_private_key_, server_ssl_config_);
404   }
405 
GetHostAndPort()406   static HostPortPair GetHostAndPort() { return HostPortPair("unittest", 0); }
407 
CreateSockets()408   void CreateSockets() {
409     client_socket_.reset();
410     server_socket_.reset();
411     channel_1_ = std::make_unique<FakeDataChannel>();
412     channel_2_ = std::make_unique<FakeDataChannel>();
413     std::unique_ptr<StreamSocket> client_connection =
414         std::make_unique<FakeSocket>(channel_1_.get(), channel_2_.get());
415     std::unique_ptr<StreamSocket> server_socket =
416         std::make_unique<FakeSocket>(channel_2_.get(), channel_1_.get());
417 
418     client_socket_ = client_context_->CreateSSLClientSocket(
419         std::move(client_connection), GetHostAndPort(), client_ssl_config_);
420     ASSERT_TRUE(client_socket_);
421 
422     server_socket_ =
423         server_context_->CreateSSLServerSocket(std::move(server_socket));
424     ASSERT_TRUE(server_socket_);
425   }
426 
427 // Client certificates are disabled on iOS.
428 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
ConfigureClientCertsForClient(const char * cert_file_name,const char * private_key_file_name)429   void ConfigureClientCertsForClient(const char* cert_file_name,
430                                      const char* private_key_file_name) {
431     scoped_refptr<X509Certificate> client_cert =
432         ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
433     ASSERT_TRUE(client_cert);
434 
435     std::unique_ptr<crypto::RSAPrivateKey> key =
436         ReadTestKey(private_key_file_name);
437     ASSERT_TRUE(key);
438 
439     client_context_->SetClientCertificate(
440         GetHostAndPort(), std::move(client_cert),
441         WrapOpenSSLPrivateKey(bssl::UpRef(key->key())));
442   }
443 
ConfigureClientCertsForServer()444   void ConfigureClientCertsForServer() {
445     server_ssl_config_.client_cert_type =
446         SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
447 
448     // "CN=B CA" - DER encoded DN of the issuer of client_1.pem
449     static const uint8_t kClientCertCAName[] = {
450         0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55,
451         0x04, 0x03, 0x0c, 0x04, 0x42, 0x20, 0x43, 0x41};
452     server_ssl_config_.cert_authorities.emplace_back(
453         std::begin(kClientCertCAName), std::end(kClientCertCAName));
454 
455     scoped_refptr<X509Certificate> expected_client_cert(
456         ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
457     ASSERT_TRUE(expected_client_cert);
458 
459     client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
460 
461     server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
462   }
463 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
464 
ReadTestKey(std::string_view name)465   std::unique_ptr<crypto::RSAPrivateKey> ReadTestKey(std::string_view name) {
466     base::FilePath certs_dir(GetTestCertsDirectory());
467     base::FilePath key_path = certs_dir.AppendASCII(name);
468     std::string key_string;
469     if (!base::ReadFileToString(key_path, &key_string))
470       return nullptr;
471     std::vector<uint8_t> key_vector(
472         reinterpret_cast<const uint8_t*>(key_string.data()),
473         reinterpret_cast<const uint8_t*>(key_string.data() +
474                                          key_string.length()));
475     std::unique_ptr<crypto::RSAPrivateKey> key(
476         crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
477     return key;
478   }
479 
PumpServerToClient()480   void PumpServerToClient() {
481     const int kReadBufSize = 1024;
482     scoped_refptr<StringIOBuffer> write_buf =
483         base::MakeRefCounted<StringIOBuffer>("testing123");
484     scoped_refptr<DrainableIOBuffer> read_buf =
485         base::MakeRefCounted<DrainableIOBuffer>(
486             base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
487     TestCompletionCallback write_callback;
488     TestCompletionCallback read_callback;
489     int server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
490                                            write_callback.callback(),
491                                            TRAFFIC_ANNOTATION_FOR_TESTS);
492     EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
493     int client_ret = client_socket_->Read(
494         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
495     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
496 
497     server_ret = write_callback.GetResult(server_ret);
498     EXPECT_GT(server_ret, 0);
499     client_ret = read_callback.GetResult(client_ret);
500     ASSERT_GT(client_ret, 0);
501   }
502 
503   std::unique_ptr<FakeDataChannel> channel_1_;
504   std::unique_ptr<FakeDataChannel> channel_2_;
505   std::unique_ptr<TestSSLConfigService> ssl_config_service_;
506   std::unique_ptr<MockCertVerifier> cert_verifier_;
507   std::unique_ptr<MockClientCertVerifier> client_cert_verifier_;
508   SSLConfig client_ssl_config_;
509   // Note that this has a pointer to the `cert_verifier_`, so must be destroyed
510   // before that is.
511   SSLServerConfig server_ssl_config_;
512   std::unique_ptr<TransportSecurityState> transport_security_state_;
513   std::unique_ptr<SSLClientSessionCache> ssl_client_session_cache_;
514   std::unique_ptr<SSLClientContext> client_context_;
515   std::unique_ptr<SSLServerContext> server_context_;
516   std::unique_ptr<SSLClientSocket> client_socket_;
517   std::unique_ptr<SSLServerSocket> server_socket_;
518   std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
519   scoped_refptr<SSLPrivateKey> server_ssl_private_key_;
520   scoped_refptr<X509Certificate> server_cert_;
521 };
522 
523 class SSLServerSocketReadTest : public SSLServerSocketTest,
524                                 public ::testing::WithParamInterface<bool> {
525  protected:
SSLServerSocketReadTest()526   SSLServerSocketReadTest() : read_if_ready_enabled_(GetParam()) {}
527 
Read(StreamSocket * socket,IOBuffer * buf,int buf_len,CompletionOnceCallback callback)528   int Read(StreamSocket* socket,
529            IOBuffer* buf,
530            int buf_len,
531            CompletionOnceCallback callback) {
532     if (read_if_ready_enabled()) {
533       return socket->ReadIfReady(buf, buf_len, std::move(callback));
534     }
535     return socket->Read(buf, buf_len, std::move(callback));
536   }
537 
read_if_ready_enabled() const538   bool read_if_ready_enabled() const { return read_if_ready_enabled_; }
539 
540  private:
541   const bool read_if_ready_enabled_;
542 };
543 
544 INSTANTIATE_TEST_SUITE_P(/* no prefix */,
545                          SSLServerSocketReadTest,
546                          ::testing::Bool());
547 
548 // This test only executes creation of client and server sockets. This is to
549 // test that creation of sockets doesn't crash and have minimal code to run
550 // with memory leak/corruption checking tools.
TEST_F(SSLServerSocketTest,Initialize)551 TEST_F(SSLServerSocketTest, Initialize) {
552   ASSERT_NO_FATAL_FAILURE(CreateContext());
553   ASSERT_NO_FATAL_FAILURE(CreateSockets());
554 }
555 
556 // This test executes Connect() on SSLClientSocket and Handshake() on
557 // SSLServerSocket to make sure handshaking between the two sockets is
558 // completed successfully.
TEST_F(SSLServerSocketTest,Handshake)559 TEST_F(SSLServerSocketTest, Handshake) {
560   ASSERT_NO_FATAL_FAILURE(CreateContext());
561   ASSERT_NO_FATAL_FAILURE(CreateSockets());
562 
563   TestCompletionCallback handshake_callback;
564   int server_ret = server_socket_->Handshake(handshake_callback.callback());
565 
566   TestCompletionCallback connect_callback;
567   int client_ret = client_socket_->Connect(connect_callback.callback());
568 
569   client_ret = connect_callback.GetResult(client_ret);
570   server_ret = handshake_callback.GetResult(server_ret);
571 
572   ASSERT_THAT(client_ret, IsOk());
573   ASSERT_THAT(server_ret, IsOk());
574 
575   // Make sure the cert status is expected.
576   SSLInfo ssl_info;
577   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
578   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
579 
580   // The default cipher suite should be ECDHE and an AEAD.
581   uint16_t cipher_suite =
582       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
583   const char* key_exchange;
584   const char* cipher;
585   const char* mac;
586   bool is_aead;
587   bool is_tls13;
588   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
589                           cipher_suite);
590   EXPECT_TRUE(is_aead);
591 }
592 
593 // This test makes sure the session cache is working.
TEST_F(SSLServerSocketTest,HandshakeCached)594 TEST_F(SSLServerSocketTest, HandshakeCached) {
595   ASSERT_NO_FATAL_FAILURE(CreateContext());
596   ASSERT_NO_FATAL_FAILURE(CreateSockets());
597 
598   TestCompletionCallback handshake_callback;
599   int server_ret = server_socket_->Handshake(handshake_callback.callback());
600 
601   TestCompletionCallback connect_callback;
602   int client_ret = client_socket_->Connect(connect_callback.callback());
603 
604   client_ret = connect_callback.GetResult(client_ret);
605   server_ret = handshake_callback.GetResult(server_ret);
606 
607   ASSERT_THAT(client_ret, IsOk());
608   ASSERT_THAT(server_ret, IsOk());
609 
610   // Make sure the cert status is expected.
611   SSLInfo ssl_info;
612   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
613   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
614   SSLInfo ssl_server_info;
615   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
616   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
617 
618   // Pump client read to get new session tickets.
619   PumpServerToClient();
620 
621   // Make sure the second connection is cached.
622   ASSERT_NO_FATAL_FAILURE(CreateSockets());
623   TestCompletionCallback handshake_callback2;
624   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
625 
626   TestCompletionCallback connect_callback2;
627   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
628 
629   client_ret2 = connect_callback2.GetResult(client_ret2);
630   server_ret2 = handshake_callback2.GetResult(server_ret2);
631 
632   ASSERT_THAT(client_ret2, IsOk());
633   ASSERT_THAT(server_ret2, IsOk());
634 
635   // Make sure the cert status is expected.
636   SSLInfo ssl_info2;
637   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
638   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
639   SSLInfo ssl_server_info2;
640   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
641   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
642 }
643 
644 // This test makes sure the session cache separates out by server context.
TEST_F(SSLServerSocketTest,HandshakeCachedContextSwitch)645 TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) {
646   ASSERT_NO_FATAL_FAILURE(CreateContext());
647   ASSERT_NO_FATAL_FAILURE(CreateSockets());
648 
649   TestCompletionCallback handshake_callback;
650   int server_ret = server_socket_->Handshake(handshake_callback.callback());
651 
652   TestCompletionCallback connect_callback;
653   int client_ret = client_socket_->Connect(connect_callback.callback());
654 
655   client_ret = connect_callback.GetResult(client_ret);
656   server_ret = handshake_callback.GetResult(server_ret);
657 
658   ASSERT_THAT(client_ret, IsOk());
659   ASSERT_THAT(server_ret, IsOk());
660 
661   // Make sure the cert status is expected.
662   SSLInfo ssl_info;
663   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
664   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
665   SSLInfo ssl_server_info;
666   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
667   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
668 
669   // Make sure the second connection is NOT cached when using a new context.
670   ASSERT_NO_FATAL_FAILURE(CreateContext());
671   ASSERT_NO_FATAL_FAILURE(CreateSockets());
672 
673   TestCompletionCallback handshake_callback2;
674   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
675 
676   TestCompletionCallback connect_callback2;
677   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
678 
679   client_ret2 = connect_callback2.GetResult(client_ret2);
680   server_ret2 = handshake_callback2.GetResult(server_ret2);
681 
682   ASSERT_THAT(client_ret2, IsOk());
683   ASSERT_THAT(server_ret2, IsOk());
684 
685   // Make sure the cert status is expected.
686   SSLInfo ssl_info2;
687   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
688   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
689   SSLInfo ssl_server_info2;
690   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
691   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL);
692 }
693 
694 // Client certificates are disabled on iOS.
695 #if BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
696 // This test executes Connect() on SSLClientSocket and Handshake() on
697 // SSLServerSocket to make sure handshaking between the two sockets is
698 // completed successfully, using client certificate.
TEST_F(SSLServerSocketTest,HandshakeWithClientCert)699 TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
700   scoped_refptr<X509Certificate> client_cert =
701       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
702   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
703       kClientCertFileName, kClientPrivateKeyFileName));
704   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
705   ASSERT_NO_FATAL_FAILURE(CreateContext());
706   ASSERT_NO_FATAL_FAILURE(CreateSockets());
707 
708   TestCompletionCallback handshake_callback;
709   int server_ret = server_socket_->Handshake(handshake_callback.callback());
710 
711   TestCompletionCallback connect_callback;
712   int client_ret = client_socket_->Connect(connect_callback.callback());
713 
714   client_ret = connect_callback.GetResult(client_ret);
715   server_ret = handshake_callback.GetResult(server_ret);
716 
717   ASSERT_THAT(client_ret, IsOk());
718   ASSERT_THAT(server_ret, IsOk());
719 
720   // Make sure the cert status is expected.
721   SSLInfo ssl_info;
722   client_socket_->GetSSLInfo(&ssl_info);
723   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
724   server_socket_->GetSSLInfo(&ssl_info);
725   ASSERT_TRUE(ssl_info.cert.get());
726   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_info.cert.get()));
727 }
728 
729 // This test executes Connect() on SSLClientSocket and Handshake() twice on
730 // SSLServerSocket to make sure handshaking between the two sockets is
731 // completed successfully, using client certificate. The second connection is
732 // expected to succeed through the session cache.
TEST_F(SSLServerSocketTest,HandshakeWithClientCertCached)733 TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) {
734   scoped_refptr<X509Certificate> client_cert =
735       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
736   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
737       kClientCertFileName, kClientPrivateKeyFileName));
738   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
739   ASSERT_NO_FATAL_FAILURE(CreateContext());
740   ASSERT_NO_FATAL_FAILURE(CreateSockets());
741 
742   TestCompletionCallback handshake_callback;
743   int server_ret = server_socket_->Handshake(handshake_callback.callback());
744 
745   TestCompletionCallback connect_callback;
746   int client_ret = client_socket_->Connect(connect_callback.callback());
747 
748   client_ret = connect_callback.GetResult(client_ret);
749   server_ret = handshake_callback.GetResult(server_ret);
750 
751   ASSERT_THAT(client_ret, IsOk());
752   ASSERT_THAT(server_ret, IsOk());
753 
754   // Make sure the cert status is expected.
755   SSLInfo ssl_info;
756   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
757   EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
758   SSLInfo ssl_server_info;
759   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info));
760   ASSERT_TRUE(ssl_server_info.cert.get());
761   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info.cert.get()));
762   EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL);
763   // Pump client read to get new session tickets.
764   PumpServerToClient();
765   server_socket_->Disconnect();
766   client_socket_->Disconnect();
767 
768   // Create the connection again.
769   ASSERT_NO_FATAL_FAILURE(CreateSockets());
770   TestCompletionCallback handshake_callback2;
771   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
772 
773   TestCompletionCallback connect_callback2;
774   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
775 
776   client_ret2 = connect_callback2.GetResult(client_ret2);
777   server_ret2 = handshake_callback2.GetResult(server_ret2);
778 
779   ASSERT_THAT(client_ret2, IsOk());
780   ASSERT_THAT(server_ret2, IsOk());
781 
782   // Make sure the cert status is expected.
783   SSLInfo ssl_info2;
784   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2));
785   EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
786   SSLInfo ssl_server_info2;
787   ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2));
788   ASSERT_TRUE(ssl_server_info2.cert.get());
789   EXPECT_TRUE(client_cert->EqualsExcludingChain(ssl_server_info2.cert.get()));
790   EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME);
791 }
792 
// The server requires a client certificate, but the client keeps its default
// configuration and sends none. Checks the following:
// - The client's Connect() fails with ERR_SSL_CLIENT_AUTH_CERT_NEEDED.
// - The CertificateRequest the client received names the issuer of the test
//   client certificate.
// - The server's Handshake() ends with ERR_CONNECTION_CLOSED once the client
//   disconnects.
TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());
  // ConfigureClientCertsForClient() is deliberately not called, so the client
  // sends no certificate.

  TestCompletionCallback server_cb;
  const int server_rv = server_socket_->Handshake(server_cb.callback());

  TestCompletionCallback client_cb;
  const int client_rv =
      client_cb.GetResult(client_socket_->Connect(client_cb.callback()));
  EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, client_rv);

  // Inspect the CertificateRequest that reached the client.
  auto cert_request = base::MakeRefCounted<SSLCertRequestInfo>();
  client_socket_->GetSSLCertRequestInfo(cert_request.get());

  // The requested authorities must include the issuer of the test client cert.
  scoped_refptr<X509Certificate> client_cert =
      ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
  ASSERT_TRUE(client_cert);
  EXPECT_TRUE(client_cert->IsIssuedByEncoded(cert_request->cert_authorities));

  client_socket_->Disconnect();

  EXPECT_THAT(server_cb.GetResult(server_rv), IsError(ERR_CONNECTION_CLOSED));
}
826 
TEST_F(SSLServerSocketTest,HandshakeWithClientCertRequiredNotSuppliedCached)827 TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) {
828   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
829   ASSERT_NO_FATAL_FAILURE(CreateContext());
830   ASSERT_NO_FATAL_FAILURE(CreateSockets());
831   // Use the default setting for the client socket, which is to not send
832   // a client certificate. This will cause the client to receive an
833   // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the
834   // requested cert_authorities from the CertificateRequest sent by the
835   // server.
836 
837   TestCompletionCallback handshake_callback;
838   int server_ret = server_socket_->Handshake(handshake_callback.callback());
839 
840   TestCompletionCallback connect_callback;
841   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
842             connect_callback.GetResult(
843                 client_socket_->Connect(connect_callback.callback())));
844 
845   auto request_info = base::MakeRefCounted<SSLCertRequestInfo>();
846   client_socket_->GetSSLCertRequestInfo(request_info.get());
847 
848   // Check that the authority name that arrived in the CertificateRequest
849   // handshake message is as expected.
850   scoped_refptr<X509Certificate> client_cert =
851       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
852   ASSERT_TRUE(client_cert);
853   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities));
854 
855   client_socket_->Disconnect();
856 
857   EXPECT_THAT(handshake_callback.GetResult(server_ret),
858               IsError(ERR_CONNECTION_CLOSED));
859   server_socket_->Disconnect();
860 
861   // Below, check that the cache didn't store the result of a failed handshake.
862   ASSERT_NO_FATAL_FAILURE(CreateSockets());
863   TestCompletionCallback handshake_callback2;
864   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
865 
866   TestCompletionCallback connect_callback2;
867   EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
868             connect_callback2.GetResult(
869                 client_socket_->Connect(connect_callback2.callback())));
870 
871   auto request_info2 = base::MakeRefCounted<SSLCertRequestInfo>();
872   client_socket_->GetSSLCertRequestInfo(request_info2.get());
873 
874   // Check that the authority name that arrived in the CertificateRequest
875   // handshake message is as expected.
876   EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info2->cert_authorities));
877 
878   client_socket_->Disconnect();
879 
880   EXPECT_THAT(handshake_callback2.GetResult(server_ret2),
881               IsError(ERR_CONNECTION_CLOSED));
882 }
883 
// The client presents the "wrong" certificate/key pair
// (kWrongClientCertFileName / kWrongClientPrivateKeyFileName), which the
// server rejects. The server's Handshake() must fail with
// ERR_BAD_SSL_CLIENT_AUTH_CERT. In TLS 1.3 the client's Connect() still
// succeeds, and the same error only surfaces on the client's first Read().
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  // In TLS 1.3, the client cert error isn't exposed until Read is called.
  EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));

  // Pump a client read so the server's rejection reaches the client.
  const int kReadBufSize = 1024;
  scoped_refptr<DrainableIOBuffer> read_buf =
      base::MakeRefCounted<DrainableIOBuffer>(
          base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
  TestCompletionCallback read_callback;
  client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
                                    read_callback.callback());
  client_ret = read_callback.GetResult(client_ret);
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
}
917 
// TLS 1.2 variant of the wrong-client-certificate case. With the client
// capped at TLS 1.2, the rejection is already reported by the client's
// Connect(). Both ends must fail with ERR_BAD_SSL_CLIENT_AUTH_CERT.
TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedTLS12) {
  client_ssl_config_.version_max_override = SSL_PROTOCOL_VERSION_TLS1_2;
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
      kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
  ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback handshake_callback;
  int server_ret = server_socket_->Handshake(handshake_callback.callback());

  TestCompletionCallback connect_callback;
  int client_ret = client_socket_->Connect(connect_callback.callback());

  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            connect_callback.GetResult(client_ret));
  EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
            handshake_callback.GetResult(server_ret));
}
941 
TEST_F(SSLServerSocketTest,HandshakeWithWrongClientCertSuppliedCached)942 TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) {
943   scoped_refptr<X509Certificate> client_cert =
944       ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
945   ASSERT_TRUE(client_cert);
946 
947   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient(
948       kWrongClientCertFileName, kWrongClientPrivateKeyFileName));
949   ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer());
950   ASSERT_NO_FATAL_FAILURE(CreateContext());
951   ASSERT_NO_FATAL_FAILURE(CreateSockets());
952 
953   TestCompletionCallback handshake_callback;
954   int server_ret = server_socket_->Handshake(handshake_callback.callback());
955 
956   TestCompletionCallback connect_callback;
957   int client_ret = client_socket_->Connect(connect_callback.callback());
958 
959   // In TLS 1.3, the client cert error isn't exposed until Read is called.
960   EXPECT_EQ(OK, connect_callback.GetResult(client_ret));
961   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
962             handshake_callback.GetResult(server_ret));
963 
964   // Pump client read to get client cert error.
965   const int kReadBufSize = 1024;
966   scoped_refptr<DrainableIOBuffer> read_buf =
967       base::MakeRefCounted<DrainableIOBuffer>(
968           base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
969   TestCompletionCallback read_callback;
970   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
971                                     read_callback.callback());
972   client_ret = read_callback.GetResult(client_ret);
973   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
974 
975   client_socket_->Disconnect();
976   server_socket_->Disconnect();
977 
978   // Below, check that the cache didn't store the result of a failed handshake.
979   ASSERT_NO_FATAL_FAILURE(CreateSockets());
980   TestCompletionCallback handshake_callback2;
981   int server_ret2 = server_socket_->Handshake(handshake_callback2.callback());
982 
983   TestCompletionCallback connect_callback2;
984   int client_ret2 = client_socket_->Connect(connect_callback2.callback());
985 
986   // In TLS 1.3, the client cert error isn't exposed until Read is called.
987   EXPECT_EQ(OK, connect_callback2.GetResult(client_ret2));
988   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
989             handshake_callback2.GetResult(server_ret2));
990 
991   // Pump client read to get client cert error.
992   client_ret = client_socket_->Read(read_buf.get(), read_buf->BytesRemaining(),
993                                     read_callback.callback());
994   client_ret = read_callback.GetResult(client_ret);
995   EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, client_ret);
996 }
997 #endif  // BUILDFLAG(ENABLE_CLIENT_CERTIFICATES)
998 
TEST_P(SSLServerSocketReadTest,DataTransfer)999 TEST_P(SSLServerSocketReadTest, DataTransfer) {
1000   ASSERT_NO_FATAL_FAILURE(CreateContext());
1001   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1002 
1003   // Establish connection.
1004   TestCompletionCallback connect_callback;
1005   int client_ret = client_socket_->Connect(connect_callback.callback());
1006   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1007 
1008   TestCompletionCallback handshake_callback;
1009   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1010   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1011 
1012   client_ret = connect_callback.GetResult(client_ret);
1013   ASSERT_THAT(client_ret, IsOk());
1014   server_ret = handshake_callback.GetResult(server_ret);
1015   ASSERT_THAT(server_ret, IsOk());
1016 
1017   const int kReadBufSize = 1024;
1018   scoped_refptr<StringIOBuffer> write_buf =
1019       base::MakeRefCounted<StringIOBuffer>("testing123");
1020   scoped_refptr<DrainableIOBuffer> read_buf =
1021       base::MakeRefCounted<DrainableIOBuffer>(
1022           base::MakeRefCounted<IOBufferWithSize>(kReadBufSize), kReadBufSize);
1023 
1024   // Write then read.
1025   TestCompletionCallback write_callback;
1026   TestCompletionCallback read_callback;
1027   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1028                                      write_callback.callback(),
1029                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1030   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1031   client_ret = client_socket_->Read(
1032       read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1033   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1034 
1035   server_ret = write_callback.GetResult(server_ret);
1036   EXPECT_GT(server_ret, 0);
1037   client_ret = read_callback.GetResult(client_ret);
1038   ASSERT_GT(client_ret, 0);
1039 
1040   read_buf->DidConsume(client_ret);
1041   while (read_buf->BytesConsumed() < write_buf->size()) {
1042     client_ret = client_socket_->Read(
1043         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
1044     EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1045     client_ret = read_callback.GetResult(client_ret);
1046     ASSERT_GT(client_ret, 0);
1047     read_buf->DidConsume(client_ret);
1048   }
1049   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1050   read_buf->SetOffset(0);
1051   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1052 
1053   // Read then write.
1054   write_buf = base::MakeRefCounted<StringIOBuffer>("hello123");
1055   server_ret = Read(server_socket_.get(), read_buf.get(),
1056                     read_buf->BytesRemaining(), read_callback.callback());
1057   EXPECT_EQ(server_ret, ERR_IO_PENDING);
1058   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1059                                      write_callback.callback(),
1060                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1061   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1062 
1063   server_ret = read_callback.GetResult(server_ret);
1064   if (read_if_ready_enabled()) {
1065     // ReadIfReady signals the data is available but does not consume it.
1066     // The data is consumed later below.
1067     ASSERT_EQ(server_ret, OK);
1068   } else {
1069     ASSERT_GT(server_ret, 0);
1070     read_buf->DidConsume(server_ret);
1071   }
1072   client_ret = write_callback.GetResult(client_ret);
1073   EXPECT_GT(client_ret, 0);
1074 
1075   while (read_buf->BytesConsumed() < write_buf->size()) {
1076     server_ret = Read(server_socket_.get(), read_buf.get(),
1077                       read_buf->BytesRemaining(), read_callback.callback());
1078     // All the data was written above, so the data should be synchronously
1079     // available out of both Read() and ReadIfReady().
1080     ASSERT_GT(server_ret, 0);
1081     read_buf->DidConsume(server_ret);
1082   }
1083   EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed());
1084   read_buf->SetOffset(0);
1085   EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
1086 }
1087 
1088 // A regression test for bug 127822 (http://crbug.com/127822).
1089 // If the server closes the connection after the handshake is finished,
1090 // the client's Write() call should not cause an infinite loop.
1091 // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket.
TEST_F(SSLServerSocketTest,ClientWriteAfterServerClose)1092 TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
1093   ASSERT_NO_FATAL_FAILURE(CreateContext());
1094   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1095 
1096   // Establish connection.
1097   TestCompletionCallback connect_callback;
1098   int client_ret = client_socket_->Connect(connect_callback.callback());
1099   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1100 
1101   TestCompletionCallback handshake_callback;
1102   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1103   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1104 
1105   client_ret = connect_callback.GetResult(client_ret);
1106   ASSERT_THAT(client_ret, IsOk());
1107   server_ret = handshake_callback.GetResult(server_ret);
1108   ASSERT_THAT(server_ret, IsOk());
1109 
1110   scoped_refptr<StringIOBuffer> write_buf =
1111       base::MakeRefCounted<StringIOBuffer>("testing123");
1112 
1113   // The server closes the connection. The server needs to write some
1114   // data first so that the client's Read() calls from the transport
1115   // socket won't return ERR_IO_PENDING.  This ensures that the client
1116   // will call Read() on the transport socket again.
1117   TestCompletionCallback write_callback;
1118   server_ret = server_socket_->Write(write_buf.get(), write_buf->size(),
1119                                      write_callback.callback(),
1120                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1121   EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
1122 
1123   server_ret = write_callback.GetResult(server_ret);
1124   EXPECT_GT(server_ret, 0);
1125 
1126   server_socket_->Disconnect();
1127 
1128   // The client writes some data. This should not cause an infinite loop.
1129   client_ret = client_socket_->Write(write_buf.get(), write_buf->size(),
1130                                      write_callback.callback(),
1131                                      TRAFFIC_ANNOTATION_FOR_TESTS);
1132   EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING);
1133 
1134   client_ret = write_callback.GetResult(client_ret);
1135   EXPECT_GT(client_ret, 0);
1136 
1137   base::RunLoop run_loop;
1138   base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
1139       FROM_HERE, run_loop.QuitClosure(), base::Milliseconds(10));
1140   run_loop.Run();
1141 }
1142 
1143 // This test executes ExportKeyingMaterial() on the client and server sockets,
1144 // after connecting them, and verifies that the results match.
1145 // This test will fail if False Start is enabled (see crbug.com/90208).
TEST_F(SSLServerSocketTest,ExportKeyingMaterial)1146 TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
1147   ASSERT_NO_FATAL_FAILURE(CreateContext());
1148   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1149 
1150   TestCompletionCallback connect_callback;
1151   int client_ret = client_socket_->Connect(connect_callback.callback());
1152   ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
1153 
1154   TestCompletionCallback handshake_callback;
1155   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1156   ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
1157 
1158   if (client_ret == ERR_IO_PENDING) {
1159     ASSERT_THAT(connect_callback.WaitForResult(), IsOk());
1160   }
1161   if (server_ret == ERR_IO_PENDING) {
1162     ASSERT_THAT(handshake_callback.WaitForResult(), IsOk());
1163   }
1164 
1165   const int kKeyingMaterialSize = 32;
1166   const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test";
1167   const char kKeyingContext[] = "";
1168   unsigned char server_out[kKeyingMaterialSize];
1169   int rv = server_socket_->ExportKeyingMaterial(
1170       kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out));
1171   ASSERT_THAT(rv, IsOk());
1172 
1173   unsigned char client_out[kKeyingMaterialSize];
1174   rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext,
1175                                             client_out, sizeof(client_out));
1176   ASSERT_THAT(rv, IsOk());
1177   EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out)));
1178 
1179   const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad";
1180   unsigned char client_bad[kKeyingMaterialSize];
1181   rv = client_socket_->ExportKeyingMaterial(
1182       kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad));
1183   ASSERT_EQ(rv, OK);
1184   EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out)));
1185 }
1186 
1187 // Verifies that SSLConfig::require_ecdhe flags works properly.
TEST_F(SSLServerSocketTest,RequireEcdheFlag)1188 TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
1189   // Disable all ECDHE suites on the client side.
1190   SSLContextConfig config;
1191   config.disabled_cipher_suites.assign(
1192       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1193 
1194   // Legacy RSA key exchange ciphers only exist in TLS 1.2 and below.
1195   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1196   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1197 
1198   // Require ECDHE on the server.
1199   server_ssl_config_.require_ecdhe = true;
1200 
1201   ASSERT_NO_FATAL_FAILURE(CreateContext());
1202   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1203 
1204   TestCompletionCallback connect_callback;
1205   int client_ret = client_socket_->Connect(connect_callback.callback());
1206 
1207   TestCompletionCallback handshake_callback;
1208   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1209 
1210   client_ret = connect_callback.GetResult(client_ret);
1211   server_ret = handshake_callback.GetResult(server_ret);
1212 
1213   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1214   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1215 }
1216 
1217 // This test executes Connect() on SSLClientSocket and Handshake() on
1218 // SSLServerSocket to make sure handshaking between the two sockets is
1219 // completed successfully. The server key is represented by SSLPrivateKey.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKey)1220 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKey) {
1221   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1222   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1223 
1224   TestCompletionCallback handshake_callback;
1225   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1226 
1227   TestCompletionCallback connect_callback;
1228   int client_ret = client_socket_->Connect(connect_callback.callback());
1229 
1230   client_ret = connect_callback.GetResult(client_ret);
1231   server_ret = handshake_callback.GetResult(server_ret);
1232 
1233   ASSERT_THAT(client_ret, IsOk());
1234   ASSERT_THAT(server_ret, IsOk());
1235 
1236   // Make sure the cert status is expected.
1237   SSLInfo ssl_info;
1238   ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info));
1239   EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
1240 
1241   // The default cipher suite should be ECDHE and an AEAD.
1242   uint16_t cipher_suite =
1243       SSLConnectionStatusToCipherSuite(ssl_info.connection_status);
1244   const char* key_exchange;
1245   const char* cipher;
1246   const char* mac;
1247   bool is_aead;
1248   bool is_tls13;
1249   SSLCipherSuiteToStrings(&key_exchange, &cipher, &mac, &is_aead, &is_tls13,
1250                           cipher_suite);
1251   EXPECT_TRUE(is_aead);
1252 }
1253 
1254 namespace {
1255 
1256 // Helper that wraps an underlying SSLPrivateKey to allow the test to
1257 // do some work immediately before a `Sign()` operation is performed.
1258 class SSLPrivateKeyHook : public SSLPrivateKey {
1259  public:
SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,base::RepeatingClosure on_sign)1260   SSLPrivateKeyHook(scoped_refptr<SSLPrivateKey> private_key,
1261                     base::RepeatingClosure on_sign)
1262       : private_key_(std::move(private_key)), on_sign_(std::move(on_sign)) {}
1263 
1264   // SSLPrivateKey implementation.
GetProviderName()1265   std::string GetProviderName() override {
1266     return private_key_->GetProviderName();
1267   }
GetAlgorithmPreferences()1268   std::vector<uint16_t> GetAlgorithmPreferences() override {
1269     return private_key_->GetAlgorithmPreferences();
1270   }
Sign(uint16_t algorithm,base::span<const uint8_t> input,SignCallback callback)1271   void Sign(uint16_t algorithm,
1272             base::span<const uint8_t> input,
1273             SignCallback callback) override {
1274     on_sign_.Run();
1275     private_key_->Sign(algorithm, input, std::move(callback));
1276   }
1277 
1278  private:
1279   ~SSLPrivateKeyHook() override = default;
1280 
1281   const scoped_refptr<SSLPrivateKey> private_key_;
1282   const base::RepeatingClosure on_sign_;
1283 };
1284 
1285 }  // namespace
1286 
1287 // Verifies that if the client disconnects while during private key signing then
1288 // the disconnection is correctly reported to the `Handshake()` completion
1289 // callback, with `ERR_CONNECTION_CLOSED`.
1290 // This is a regression test for crbug.com/1449461.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError)1291 TEST_F(SSLServerSocketTest,
1292        HandshakeServerSSLPrivateKeyDisconnectDuringSigning_ReturnsError) {
1293   auto on_sign = base::BindLambdaForTesting([&]() {
1294     client_socket_->Disconnect();
1295     ASSERT_FALSE(client_socket_->IsConnected());
1296   });
1297   server_ssl_private_key_ = base::MakeRefCounted<SSLPrivateKeyHook>(
1298       std::move(server_ssl_private_key_), on_sign);
1299   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1300   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1301 
1302   TestCompletionCallback handshake_callback;
1303   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1304   ASSERT_EQ(server_ret, net::ERR_IO_PENDING);
1305 
1306   TestCompletionCallback connect_callback;
1307   client_socket_->Connect(connect_callback.callback());
1308 
1309   // If resuming the handshake after private-key signing is not handled
1310   // correctly as per crbug.com/1449461 then the test will hang and timeout
1311   // at this point, due to the server-side completion callback not being
1312   // correctly invoked.
1313   server_ret = handshake_callback.GetResult(server_ret);
1314   EXPECT_EQ(server_ret, net::ERR_CONNECTION_CLOSED);
1315 }
1316 
1317 // Verifies that non-ECDHE ciphers are disabled when using SSLPrivateKey as the
1318 // server key.
TEST_F(SSLServerSocketTest,HandshakeServerSSLPrivateKeyRequireEcdhe)1319 TEST_F(SSLServerSocketTest, HandshakeServerSSLPrivateKeyRequireEcdhe) {
1320   // Disable all ECDHE suites on the client side.
1321   SSLContextConfig config;
1322   config.disabled_cipher_suites.assign(
1323       kEcdheCiphers, kEcdheCiphers + std::size(kEcdheCiphers));
1324   // TLS 1.3 always works with SSLPrivateKey.
1325   config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
1326   ssl_config_service_->UpdateSSLConfigAndNotify(config);
1327 
1328   ASSERT_NO_FATAL_FAILURE(CreateContextSSLPrivateKey());
1329   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1330 
1331   TestCompletionCallback connect_callback;
1332   int client_ret = client_socket_->Connect(connect_callback.callback());
1333 
1334   TestCompletionCallback handshake_callback;
1335   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1336 
1337   client_ret = connect_callback.GetResult(client_ret);
1338   server_ret = handshake_callback.GetResult(server_ret);
1339 
1340   ASSERT_THAT(client_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1341   ASSERT_THAT(server_ret, IsError(ERR_SSL_VERSION_OR_CIPHER_MISMATCH));
1342 }
1343 
1344 class SSLServerSocketAlpsTest
1345     : public SSLServerSocketTest,
1346       public ::testing::WithParamInterface<std::tuple<bool, bool>> {
1347  public:
SSLServerSocketAlpsTest()1348   SSLServerSocketAlpsTest()
1349       : client_alps_enabled_(std::get<0>(GetParam())),
1350         server_alps_enabled_(std::get<1>(GetParam())) {}
1351   ~SSLServerSocketAlpsTest() override = default;
1352   const bool client_alps_enabled_;
1353   const bool server_alps_enabled_;
1354 };
1355 
1356 INSTANTIATE_TEST_SUITE_P(All,
1357                          SSLServerSocketAlpsTest,
1358                          ::testing::Combine(::testing::Bool(),
1359                                             ::testing::Bool()));
1360 
// ALPS data is exchanged only when both peers enable it. For each
// combination of client/server enablement, checks the following:
// - When both are enabled, each side sees exactly the other side's settings.
// - Otherwise, neither side sees any settings.
TEST_P(SSLServerSocketAlpsTest, Alps) {
  const std::string server_data = "server sends some test data";
  const std::string client_data = "client also sends some data";

  // Both sides offer HTTP/2 via ALPN. ALPS payloads are attached only where
  // the test parameter enables them.
  server_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (server_alps_enabled_) {
    server_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(server_data.begin(), server_data.end());
  }

  client_ssl_config_.alpn_protos = {kProtoHTTP2};
  if (client_alps_enabled_) {
    client_ssl_config_.application_settings[kProtoHTTP2] =
        std::vector<uint8_t>(client_data.begin(), client_data.end());
  }

  ASSERT_NO_FATAL_FAILURE(CreateContext());
  ASSERT_NO_FATAL_FAILURE(CreateSockets());

  TestCompletionCallback server_cb;
  int server_rv = server_socket_->Handshake(server_cb.callback());
  TestCompletionCallback client_cb;
  int client_rv = client_socket_->Connect(client_cb.callback());

  client_rv = client_cb.GetResult(client_rv);
  server_rv = server_cb.GetResult(server_rv);
  ASSERT_THAT(client_rv, IsOk());
  ASSERT_THAT(server_rv, IsOk());

  const auto seen_by_client = client_socket_->GetPeerApplicationSettings();
  const auto seen_by_server = server_socket_->GetPeerApplicationSettings();

  const bool alps_negotiated = client_alps_enabled_ && server_alps_enabled_;
  if (!alps_negotiated) {
    EXPECT_FALSE(seen_by_client.has_value());
    EXPECT_FALSE(seen_by_server.has_value());
    return;
  }

  ASSERT_TRUE(seen_by_client.has_value());
  EXPECT_EQ(server_data, seen_by_client.value());
  ASSERT_TRUE(seen_by_server.has_value());
  EXPECT_EQ(client_data, seen_by_server.value());
}
1408 
1409 // Test that CancelReadIfReady works.
TEST_F(SSLServerSocketTest,CancelReadIfReady)1410 TEST_F(SSLServerSocketTest, CancelReadIfReady) {
1411   ASSERT_NO_FATAL_FAILURE(CreateContext());
1412   ASSERT_NO_FATAL_FAILURE(CreateSockets());
1413 
1414   TestCompletionCallback connect_callback;
1415   int client_ret = client_socket_->Connect(connect_callback.callback());
1416   TestCompletionCallback handshake_callback;
1417   int server_ret = server_socket_->Handshake(handshake_callback.callback());
1418   ASSERT_THAT(connect_callback.GetResult(client_ret), IsOk());
1419   ASSERT_THAT(handshake_callback.GetResult(server_ret), IsOk());
1420 
1421   // Attempt to read from the server socket. There will not be anything to read.
1422   // Cancel the read immediately afterwards.
1423   TestCompletionCallback read_callback;
1424   auto read_buf = base::MakeRefCounted<IOBufferWithSize>(1);
1425   int read_ret =
1426       server_socket_->ReadIfReady(read_buf.get(), 1, read_callback.callback());
1427   ASSERT_THAT(read_ret, IsError(ERR_IO_PENDING));
1428   ASSERT_THAT(server_socket_->CancelReadIfReady(), IsOk());
1429 
1430   // After the client writes data, the server should still not pick up a result.
1431   auto write_buf = base::MakeRefCounted<StringIOBuffer>("a");
1432   TestCompletionCallback write_callback;
1433   ASSERT_EQ(write_callback.GetResult(client_socket_->Write(
1434                 write_buf.get(), write_buf->size(), write_callback.callback(),
1435                 TRAFFIC_ANNOTATION_FOR_TESTS)),
1436             write_buf->size());
1437   base::RunLoop().RunUntilIdle();
1438   EXPECT_FALSE(read_callback.have_result());
1439 
1440   // After a canceled read, future reads are still possible.
1441   while (true) {
1442     TestCompletionCallback read_callback2;
1443     read_ret = server_socket_->ReadIfReady(read_buf.get(), 1,
1444                                            read_callback2.callback());
1445     if (read_ret != ERR_IO_PENDING) {
1446       break;
1447     }
1448     ASSERT_THAT(read_callback2.GetResult(read_ret), IsOk());
1449   }
1450   ASSERT_EQ(1, read_ret);
1451   EXPECT_EQ(read_buf->data()[0], 'a');
1452 }
1453 
1454 }  // namespace net
1455