xref: /aosp_15_r20/external/webrtc/p2p/base/dtls_transport_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2011 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "p2p/base/dtls_transport.h"
12 
13 #include <algorithm>
14 #include <memory>
15 #include <set>
16 #include <utility>
17 
18 #include "absl/strings/string_view.h"
19 #include "api/dtls_transport_interface.h"
20 #include "p2p/base/fake_ice_transport.h"
21 #include "p2p/base/packet_transport_internal.h"
22 #include "rtc_base/checks.h"
23 #include "rtc_base/dscp.h"
24 #include "rtc_base/gunit.h"
25 #include "rtc_base/helpers.h"
26 #include "rtc_base/rtc_certificate.h"
27 #include "rtc_base/ssl_adapter.h"
28 #include "rtc_base/ssl_identity.h"
29 #include "rtc_base/ssl_stream_adapter.h"
30 
// Skips the current test when the SSL backend does not provide `feature`,
// where `feature` names a static predicate on rtc::SSLStreamAdapter (e.g.
// IsBoringSsl). GTEST_SKIP() makes the test show up as skipped rather than as
// a silent pass, and the do/while(0) wrapper makes the macro a single
// statement so it cannot capture a following `else`. GTEST_SKIP() returns from
// the enclosing function, so use this directly in a test body.
#define MAYBE_SKIP_TEST(feature)                                \
  do {                                                          \
    if (!(rtc::SSLStreamAdapter::feature())) {                  \
      GTEST_SKIP() << #feature " feature disabled... skipping"; \
    }                                                           \
  } while (0)
36 
37 namespace cricket {
38 
// Byte offset of the big-endian 32-bit sequence number that SendPackets()
// writes into every test packet and VerifyPacket() reads back.
constexpr size_t kPacketNumOffset = 8;
// Number of leading bytes VerifyPacket() skips before checking that the rest
// of the packet holds the per-packet fill byte.
constexpr size_t kPacketHeaderLen = 12;
// Packet id attached to every packet sent by SendPackets().
constexpr int kFakePacketId = 0x1234;
// Timeout handed to the *_WAIT / *_SIMULATED_WAIT macros (presumably in
// milliseconds — confirm against rtc_base/gunit.h).
constexpr int kTimeout = 10000;
43 
// Returns true when `b` could be the first byte of an RTP/RTCP packet, i.e.
// its two most significant bits (the RTP version field) hold the value 2.
static bool IsRtpLeadByte(uint8_t b) {
  const unsigned top_two_bits = static_cast<unsigned>(b) >> 6;
  return top_two_bits == 2u;
}
47 
48 // `modify_digest` is used to set modified fingerprints that are meant to fail
49 // validation.
SetRemoteFingerprintFromCert(DtlsTransport * transport,const rtc::scoped_refptr<rtc::RTCCertificate> & cert,bool modify_digest=false)50 void SetRemoteFingerprintFromCert(
51     DtlsTransport* transport,
52     const rtc::scoped_refptr<rtc::RTCCertificate>& cert,
53     bool modify_digest = false) {
54   std::unique_ptr<rtc::SSLFingerprint> fingerprint =
55       rtc::SSLFingerprint::CreateFromCertificate(*cert);
56   if (modify_digest) {
57     ++fingerprint->digest.MutableData()[0];
58   }
59 
60   // Even if digest is verified to be incorrect, should fail asynchronously.
61   EXPECT_TRUE(
62       transport
63           ->SetRemoteParameters(
64               fingerprint->algorithm,
65               reinterpret_cast<const uint8_t*>(fingerprint->digest.data()),
66               fingerprint->digest.size(), absl::nullopt)
67           .ok());
68 }
69 
70 class DtlsTestClient : public sigslot::has_slots<> {
71  public:
DtlsTestClient(absl::string_view name)72   explicit DtlsTestClient(absl::string_view name) : name_(name) {}
CreateCertificate(rtc::KeyType key_type)73   void CreateCertificate(rtc::KeyType key_type) {
74     certificate_ =
75         rtc::RTCCertificate::Create(rtc::SSLIdentity::Create(name_, key_type));
76   }
certificate()77   const rtc::scoped_refptr<rtc::RTCCertificate>& certificate() {
78     return certificate_;
79   }
SetupMaxProtocolVersion(rtc::SSLProtocolVersion version)80   void SetupMaxProtocolVersion(rtc::SSLProtocolVersion version) {
81     ssl_max_version_ = version;
82   }
83   // Set up fake ICE transport and real DTLS transport under test.
SetupTransports(IceRole role,int async_delay_ms=0)84   void SetupTransports(IceRole role, int async_delay_ms = 0) {
85     fake_ice_transport_.reset(new FakeIceTransport("fake", 0));
86     fake_ice_transport_->SetAsync(true);
87     fake_ice_transport_->SetAsyncDelay(async_delay_ms);
88     fake_ice_transport_->SetIceRole(role);
89     fake_ice_transport_->SetIceTiebreaker((role == ICEROLE_CONTROLLING) ? 1
90                                                                         : 2);
91     // Hook the raw packets so that we can verify they are encrypted.
92     fake_ice_transport_->SignalReadPacket.connect(
93         this, &DtlsTestClient::OnFakeIceTransportReadPacket);
94 
95     dtls_transport_ = std::make_unique<DtlsTransport>(
96         fake_ice_transport_.get(), webrtc::CryptoOptions(),
97         /*event_log=*/nullptr, ssl_max_version_);
98     // Note: Certificate may be null here if testing passthrough.
99     dtls_transport_->SetLocalCertificate(certificate_);
100     dtls_transport_->SignalWritableState.connect(
101         this, &DtlsTestClient::OnTransportWritableState);
102     dtls_transport_->SignalReadPacket.connect(
103         this, &DtlsTestClient::OnTransportReadPacket);
104     dtls_transport_->SignalSentPacket.connect(
105         this, &DtlsTestClient::OnTransportSentPacket);
106   }
107 
fake_ice_transport()108   FakeIceTransport* fake_ice_transport() {
109     return static_cast<FakeIceTransport*>(dtls_transport_->ice_transport());
110   }
111 
dtls_transport()112   DtlsTransport* dtls_transport() { return dtls_transport_.get(); }
113 
114   // Simulate fake ICE transports connecting.
Connect(DtlsTestClient * peer,bool asymmetric)115   bool Connect(DtlsTestClient* peer, bool asymmetric) {
116     fake_ice_transport()->SetDestination(peer->fake_ice_transport(),
117                                          asymmetric);
118     return true;
119   }
120 
received_dtls_client_hellos() const121   int received_dtls_client_hellos() const {
122     return received_dtls_client_hellos_;
123   }
124 
received_dtls_server_hellos() const125   int received_dtls_server_hellos() const {
126     return received_dtls_server_hellos_;
127   }
128 
CheckRole(rtc::SSLRole role)129   void CheckRole(rtc::SSLRole role) {
130     if (role == rtc::SSL_CLIENT) {
131       ASSERT_EQ(0, received_dtls_client_hellos_);
132       ASSERT_GT(received_dtls_server_hellos_, 0);
133     } else {
134       ASSERT_GT(received_dtls_client_hellos_, 0);
135       ASSERT_EQ(0, received_dtls_server_hellos_);
136     }
137   }
138 
CheckSrtp(int expected_crypto_suite)139   void CheckSrtp(int expected_crypto_suite) {
140     int crypto_suite;
141     bool rv = dtls_transport_->GetSrtpCryptoSuite(&crypto_suite);
142     if (dtls_transport_->IsDtlsActive() && expected_crypto_suite) {
143       ASSERT_TRUE(rv);
144       ASSERT_EQ(crypto_suite, expected_crypto_suite);
145     } else {
146       ASSERT_FALSE(rv);
147     }
148   }
149 
CheckSsl()150   void CheckSsl() {
151     int cipher;
152     bool rv = dtls_transport_->GetSslCipherSuite(&cipher);
153     if (dtls_transport_->IsDtlsActive()) {
154       ASSERT_TRUE(rv);
155       EXPECT_TRUE(
156           rtc::SSLStreamAdapter::IsAcceptableCipher(cipher, rtc::KT_DEFAULT));
157     } else {
158       ASSERT_FALSE(rv);
159     }
160   }
161 
SendPackets(size_t size,size_t count,bool srtp)162   void SendPackets(size_t size, size_t count, bool srtp) {
163     std::unique_ptr<char[]> packet(new char[size]);
164     size_t sent = 0;
165     do {
166       // Fill the packet with a known value and a sequence number to check
167       // against, and make sure that it doesn't look like DTLS.
168       memset(packet.get(), sent & 0xff, size);
169       packet[0] = (srtp) ? 0x80 : 0x00;
170       rtc::SetBE32(packet.get() + kPacketNumOffset,
171                    static_cast<uint32_t>(sent));
172 
173       // Only set the bypass flag if we've activated DTLS.
174       int flags = (certificate_ && srtp) ? PF_SRTP_BYPASS : 0;
175       rtc::PacketOptions packet_options;
176       packet_options.packet_id = kFakePacketId;
177       int rv = dtls_transport_->SendPacket(packet.get(), size, packet_options,
178                                            flags);
179       ASSERT_GT(rv, 0);
180       ASSERT_EQ(size, static_cast<size_t>(rv));
181       ++sent;
182     } while (sent < count);
183   }
184 
SendInvalidSrtpPacket(size_t size)185   int SendInvalidSrtpPacket(size_t size) {
186     std::unique_ptr<char[]> packet(new char[size]);
187     // Fill the packet with 0 to form an invalid SRTP packet.
188     memset(packet.get(), 0, size);
189 
190     rtc::PacketOptions packet_options;
191     return dtls_transport_->SendPacket(packet.get(), size, packet_options,
192                                        PF_SRTP_BYPASS);
193   }
194 
ExpectPackets(size_t size)195   void ExpectPackets(size_t size) {
196     packet_size_ = size;
197     received_.clear();
198   }
199 
NumPacketsReceived()200   size_t NumPacketsReceived() { return received_.size(); }
201 
202   // Inverse of SendPackets.
VerifyPacket(const char * data,size_t size,uint32_t * out_num)203   bool VerifyPacket(const char* data, size_t size, uint32_t* out_num) {
204     if (size != packet_size_ ||
205         (data[0] != 0 && static_cast<uint8_t>(data[0]) != 0x80)) {
206       return false;
207     }
208     uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
209     for (size_t i = kPacketHeaderLen; i < size; ++i) {
210       if (static_cast<uint8_t>(data[i]) != (packet_num & 0xff)) {
211         return false;
212       }
213     }
214     if (out_num) {
215       *out_num = packet_num;
216     }
217     return true;
218   }
VerifyEncryptedPacket(const char * data,size_t size)219   bool VerifyEncryptedPacket(const char* data, size_t size) {
220     // This is an encrypted data packet; let's make sure it's mostly random;
221     // less than 10% of the bytes should be equal to the cleartext packet.
222     if (size <= packet_size_) {
223       return false;
224     }
225     uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset);
226     int num_matches = 0;
227     for (size_t i = kPacketNumOffset; i < size; ++i) {
228       if (static_cast<uint8_t>(data[i]) == (packet_num & 0xff)) {
229         ++num_matches;
230       }
231     }
232     return (num_matches < ((static_cast<int>(size) - 5) / 10));
233   }
234 
235   // Transport callbacks
OnTransportWritableState(rtc::PacketTransportInternal * transport)236   void OnTransportWritableState(rtc::PacketTransportInternal* transport) {
237     RTC_LOG(LS_INFO) << name_ << ": Transport '" << transport->transport_name()
238                      << "' is writable";
239   }
240 
OnTransportReadPacket(rtc::PacketTransportInternal * transport,const char * data,size_t size,const int64_t &,int flags)241   void OnTransportReadPacket(rtc::PacketTransportInternal* transport,
242                              const char* data,
243                              size_t size,
244                              const int64_t& /* packet_time_us */,
245                              int flags) {
246     uint32_t packet_num = 0;
247     ASSERT_TRUE(VerifyPacket(data, size, &packet_num));
248     received_.insert(packet_num);
249     // Only DTLS-SRTP packets should have the bypass flag set.
250     int expected_flags =
251         (certificate_ && IsRtpLeadByte(data[0])) ? PF_SRTP_BYPASS : 0;
252     ASSERT_EQ(expected_flags, flags);
253   }
254 
OnTransportSentPacket(rtc::PacketTransportInternal * transport,const rtc::SentPacket & sent_packet)255   void OnTransportSentPacket(rtc::PacketTransportInternal* transport,
256                              const rtc::SentPacket& sent_packet) {
257     sent_packet_ = sent_packet;
258   }
259 
sent_packet() const260   rtc::SentPacket sent_packet() const { return sent_packet_; }
261 
262   // Hook into the raw packet stream to make sure DTLS packets are encrypted.
OnFakeIceTransportReadPacket(rtc::PacketTransportInternal * transport,const char * data,size_t size,const int64_t &,int flags)263   void OnFakeIceTransportReadPacket(rtc::PacketTransportInternal* transport,
264                                     const char* data,
265                                     size_t size,
266                                     const int64_t& /* packet_time_us */,
267                                     int flags) {
268     // Flags shouldn't be set on the underlying Transport packets.
269     ASSERT_EQ(0, flags);
270 
271     // Look at the handshake packets to see what role we played.
272     // Check that non-handshake packets are DTLS data or SRTP bypass.
273     if (data[0] == 22 && size > 17) {
274       if (data[13] == 1) {
275         ++received_dtls_client_hellos_;
276       } else if (data[13] == 2) {
277         ++received_dtls_server_hellos_;
278       }
279     } else if (dtls_transport_->IsDtlsActive() &&
280                !(data[0] >= 20 && data[0] <= 22)) {
281       ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0]));
282       if (data[0] == 23) {
283         ASSERT_TRUE(VerifyEncryptedPacket(data, size));
284       } else if (IsRtpLeadByte(data[0])) {
285         ASSERT_TRUE(VerifyPacket(data, size, NULL));
286       }
287     }
288   }
289 
290  private:
291   std::string name_;
292   rtc::scoped_refptr<rtc::RTCCertificate> certificate_;
293   std::unique_ptr<FakeIceTransport> fake_ice_transport_;
294   std::unique_ptr<DtlsTransport> dtls_transport_;
295   size_t packet_size_ = 0u;
296   std::set<int> received_;
297   rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
298   int received_dtls_client_hellos_ = 0;
299   int received_dtls_server_hellos_ = 0;
300   rtc::SentPacket sent_packet_;
301 };
302 
303 // Base class for DtlsTransportTest and DtlsEventOrderingTest, which
304 // inherit from different variants of ::testing::Test.
305 //
306 // Note that this test always uses a FakeClock, due to the `fake_clock_` member
307 // variable.
308 class DtlsTransportTestBase {
309  public:
DtlsTransportTestBase()310   DtlsTransportTestBase() : client1_("P1"), client2_("P2"), use_dtls_(false) {}
311 
SetMaxProtocolVersions(rtc::SSLProtocolVersion c1,rtc::SSLProtocolVersion c2)312   void SetMaxProtocolVersions(rtc::SSLProtocolVersion c1,
313                               rtc::SSLProtocolVersion c2) {
314     client1_.SetupMaxProtocolVersion(c1);
315     client2_.SetupMaxProtocolVersion(c2);
316   }
317   // If not called, DtlsTransport will be used in SRTP bypass mode.
PrepareDtls(rtc::KeyType key_type)318   void PrepareDtls(rtc::KeyType key_type) {
319     client1_.CreateCertificate(key_type);
320     client2_.CreateCertificate(key_type);
321     use_dtls_ = true;
322   }
323 
324   // This test negotiates DTLS parameters before the underlying transports are
325   // writable. DtlsEventOrderingTest is responsible for exercising differerent
326   // orderings.
Connect(bool client1_server=true)327   bool Connect(bool client1_server = true) {
328     Negotiate(client1_server);
329     EXPECT_TRUE(client1_.Connect(&client2_, false));
330 
331     EXPECT_TRUE_SIMULATED_WAIT(client1_.dtls_transport()->writable() &&
332                                    client2_.dtls_transport()->writable(),
333                                kTimeout, fake_clock_);
334     if (!client1_.dtls_transport()->writable() ||
335         !client2_.dtls_transport()->writable())
336       return false;
337 
338     // Check that we used the right roles.
339     if (use_dtls_) {
340       client1_.CheckRole(client1_server ? rtc::SSL_SERVER : rtc::SSL_CLIENT);
341       client2_.CheckRole(client1_server ? rtc::SSL_CLIENT : rtc::SSL_SERVER);
342     }
343 
344     if (use_dtls_) {
345       // Check that we negotiated the right ciphers. Since GCM ciphers are not
346       // negotiated by default, we should end up with kSrtpAes128CmSha1_80.
347       client1_.CheckSrtp(rtc::kSrtpAes128CmSha1_80);
348       client2_.CheckSrtp(rtc::kSrtpAes128CmSha1_80);
349     } else {
350       // If DTLS isn't actually being used, GetSrtpCryptoSuite should return
351       // false.
352       client1_.CheckSrtp(rtc::kSrtpInvalidCryptoSuite);
353       client2_.CheckSrtp(rtc::kSrtpInvalidCryptoSuite);
354     }
355 
356     client1_.CheckSsl();
357     client2_.CheckSsl();
358 
359     return true;
360   }
361 
Negotiate(bool client1_server=true)362   void Negotiate(bool client1_server = true) {
363     client1_.SetupTransports(ICEROLE_CONTROLLING);
364     client2_.SetupTransports(ICEROLE_CONTROLLED);
365     client1_.dtls_transport()->SetDtlsRole(client1_server ? rtc::SSL_SERVER
366                                                           : rtc::SSL_CLIENT);
367     client2_.dtls_transport()->SetDtlsRole(client1_server ? rtc::SSL_CLIENT
368                                                           : rtc::SSL_SERVER);
369     if (client2_.certificate()) {
370       SetRemoteFingerprintFromCert(client1_.dtls_transport(),
371                                    client2_.certificate());
372     }
373     if (client1_.certificate()) {
374       SetRemoteFingerprintFromCert(client2_.dtls_transport(),
375                                    client1_.certificate());
376     }
377   }
378 
TestTransfer(size_t size,size_t count,bool srtp)379   void TestTransfer(size_t size, size_t count, bool srtp) {
380     RTC_LOG(LS_INFO) << "Expect packets, size=" << size;
381     client2_.ExpectPackets(size);
382     client1_.SendPackets(size, count, srtp);
383     EXPECT_EQ_SIMULATED_WAIT(count, client2_.NumPacketsReceived(), kTimeout,
384                              fake_clock_);
385   }
386 
387  protected:
388   rtc::AutoThread main_thread_;
389   rtc::ScopedFakeClock fake_clock_;
390   DtlsTestClient client1_;
391   DtlsTestClient client2_;
392   bool use_dtls_;
393   rtc::SSLProtocolVersion ssl_expected_version_;
394 };
395 
396 class DtlsTransportTest : public DtlsTransportTestBase,
397                           public ::testing::Test {};
398 
399 // Connect without DTLS, and transfer RTP data.
TEST_F(DtlsTransportTest,TestTransferRtp)400 TEST_F(DtlsTransportTest, TestTransferRtp) {
401   ASSERT_TRUE(Connect());
402   TestTransfer(1000, 100, /*srtp=*/false);
403 }
404 
405 // Test that the SignalSentPacket signal is wired up.
TEST_F(DtlsTransportTest,TestSignalSentPacket)406 TEST_F(DtlsTransportTest, TestSignalSentPacket) {
407   ASSERT_TRUE(Connect());
408   // Sanity check default value (-1).
409   ASSERT_EQ(client1_.sent_packet().send_time_ms, -1);
410   TestTransfer(1000, 100, false);
411   // Check that we get the expected fake packet ID, and a time of 0 from the
412   // fake clock.
413   EXPECT_EQ(kFakePacketId, client1_.sent_packet().packet_id);
414   EXPECT_GE(client1_.sent_packet().send_time_ms, 0);
415 }
416 
417 // Connect without DTLS, and transfer SRTP data.
TEST_F(DtlsTransportTest,TestTransferSrtp)418 TEST_F(DtlsTransportTest, TestTransferSrtp) {
419   ASSERT_TRUE(Connect());
420   TestTransfer(1000, 100, /*srtp=*/true);
421 }
422 
423 // Connect with DTLS, and transfer data over DTLS.
TEST_F(DtlsTransportTest,TestTransferDtls)424 TEST_F(DtlsTransportTest, TestTransferDtls) {
425   PrepareDtls(rtc::KT_DEFAULT);
426   ASSERT_TRUE(Connect());
427   TestTransfer(1000, 100, /*srtp=*/false);
428 }
429 
430 // Connect with DTLS, combine multiple DTLS records into one packet.
431 // Our DTLS implementation doesn't do this, but other implementations may;
432 // see https://tools.ietf.org/html/rfc6347#section-4.1.1.
433 // This has caused interoperability problems with ORTCLib in the past.
TEST_F(DtlsTransportTest,TestTransferDtlsCombineRecords)434 TEST_F(DtlsTransportTest, TestTransferDtlsCombineRecords) {
435   PrepareDtls(rtc::KT_DEFAULT);
436   ASSERT_TRUE(Connect());
437   // Our DTLS implementation always sends one record per packet, so to simulate
438   // an endpoint that sends multiple records per packet, we configure the fake
439   // ICE transport to combine every two consecutive packets into a single
440   // packet.
441   FakeIceTransport* transport = client1_.fake_ice_transport();
442   transport->combine_outgoing_packets(true);
443   TestTransfer(500, 100, /*srtp=*/false);
444 }
445 
446 class DtlsTransportVersionTest
447     : public DtlsTransportTestBase,
448       public ::testing::TestWithParam<
449           ::testing::tuple<rtc::SSLProtocolVersion, rtc::SSLProtocolVersion>> {
450 };
451 
452 // Test that an acceptable cipher suite is negotiated when different versions
453 // of DTLS are supported. Note that it's IsAcceptableCipher that does the actual
454 // work.
TEST_P(DtlsTransportVersionTest,TestCipherSuiteNegotiation)455 TEST_P(DtlsTransportVersionTest, TestCipherSuiteNegotiation) {
456   PrepareDtls(rtc::KT_DEFAULT);
457   SetMaxProtocolVersions(::testing::get<0>(GetParam()),
458                          ::testing::get<1>(GetParam()));
459   ASSERT_TRUE(Connect());
460 }
461 
462 // Will test every combination of 1.0/1.2 on the client and server.
463 INSTANTIATE_TEST_SUITE_P(
464     TestCipherSuiteNegotiation,
465     DtlsTransportVersionTest,
466     ::testing::Combine(::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
467                                          rtc::SSL_PROTOCOL_DTLS_12),
468                        ::testing::Values(rtc::SSL_PROTOCOL_DTLS_10,
469                                          rtc::SSL_PROTOCOL_DTLS_12)));
470 
471 // Connect with DTLS, negotiating DTLS-SRTP, and transfer SRTP using bypass.
TEST_F(DtlsTransportTest,TestTransferDtlsSrtp)472 TEST_F(DtlsTransportTest, TestTransferDtlsSrtp) {
473   PrepareDtls(rtc::KT_DEFAULT);
474   ASSERT_TRUE(Connect());
475   TestTransfer(1000, 100, /*srtp=*/true);
476 }
477 
478 // Connect with DTLS-SRTP, transfer an invalid SRTP packet, and expects -1
479 // returned.
TEST_F(DtlsTransportTest,TestTransferDtlsInvalidSrtpPacket)480 TEST_F(DtlsTransportTest, TestTransferDtlsInvalidSrtpPacket) {
481   PrepareDtls(rtc::KT_DEFAULT);
482   ASSERT_TRUE(Connect());
483   EXPECT_EQ(-1, client1_.SendInvalidSrtpPacket(100));
484 }
485 
486 // Create a single transport with DTLS, and send normal data and SRTP data on
487 // it.
TEST_F(DtlsTransportTest,TestTransferDtlsSrtpDemux)488 TEST_F(DtlsTransportTest, TestTransferDtlsSrtpDemux) {
489   PrepareDtls(rtc::KT_DEFAULT);
490   ASSERT_TRUE(Connect());
491   TestTransfer(1000, 100, /*srtp=*/false);
492   TestTransfer(1000, 100, /*srtp=*/true);
493 }
494 
495 // Test transferring when the "answerer" has the server role.
TEST_F(DtlsTransportTest,TestTransferDtlsSrtpAnswererIsPassive)496 TEST_F(DtlsTransportTest, TestTransferDtlsSrtpAnswererIsPassive) {
497   PrepareDtls(rtc::KT_DEFAULT);
498   ASSERT_TRUE(Connect(/*client1_server=*/false));
499   TestTransfer(1000, 100, /*srtp=*/true);
500 }
501 
502 // Test that renegotiation (setting same role and fingerprint again) can be
503 // started before the clients become connected in the first negotiation.
TEST_F(DtlsTransportTest,TestRenegotiateBeforeConnect)504 TEST_F(DtlsTransportTest, TestRenegotiateBeforeConnect) {
505   PrepareDtls(rtc::KT_DEFAULT);
506   // Note: This is doing the same thing Connect normally does, minus some
507   // additional checks not relevant for this test.
508   Negotiate();
509   Negotiate();
510   EXPECT_TRUE(client1_.Connect(&client2_, false));
511   EXPECT_TRUE_SIMULATED_WAIT(client1_.dtls_transport()->writable() &&
512                                  client2_.dtls_transport()->writable(),
513                              kTimeout, fake_clock_);
514   TestTransfer(1000, 100, true);
515 }
516 
517 // Test Certificates state after negotiation but before connection.
TEST_F(DtlsTransportTest,TestCertificatesBeforeConnect)518 TEST_F(DtlsTransportTest, TestCertificatesBeforeConnect) {
519   PrepareDtls(rtc::KT_DEFAULT);
520   Negotiate();
521 
522   // After negotiation, each side has a distinct local certificate, but still no
523   // remote certificate, because connection has not yet occurred.
524   auto certificate1 = client1_.dtls_transport()->GetLocalCertificate();
525   auto certificate2 = client2_.dtls_transport()->GetLocalCertificate();
526   ASSERT_NE(certificate1->GetSSLCertificate().ToPEMString(),
527             certificate2->GetSSLCertificate().ToPEMString());
528   ASSERT_FALSE(client1_.dtls_transport()->GetRemoteSSLCertChain());
529   ASSERT_FALSE(client2_.dtls_transport()->GetRemoteSSLCertChain());
530 }
531 
532 // Test Certificates state after connection.
TEST_F(DtlsTransportTest,TestCertificatesAfterConnect)533 TEST_F(DtlsTransportTest, TestCertificatesAfterConnect) {
534   PrepareDtls(rtc::KT_DEFAULT);
535   ASSERT_TRUE(Connect());
536 
537   // After connection, each side has a distinct local certificate.
538   auto certificate1 = client1_.dtls_transport()->GetLocalCertificate();
539   auto certificate2 = client2_.dtls_transport()->GetLocalCertificate();
540   ASSERT_NE(certificate1->GetSSLCertificate().ToPEMString(),
541             certificate2->GetSSLCertificate().ToPEMString());
542 
543   // Each side's remote certificate is the other side's local certificate.
544   std::unique_ptr<rtc::SSLCertChain> remote_cert1 =
545       client1_.dtls_transport()->GetRemoteSSLCertChain();
546   ASSERT_TRUE(remote_cert1);
547   ASSERT_EQ(1u, remote_cert1->GetSize());
548   ASSERT_EQ(remote_cert1->Get(0).ToPEMString(),
549             certificate2->GetSSLCertificate().ToPEMString());
550   std::unique_ptr<rtc::SSLCertChain> remote_cert2 =
551       client2_.dtls_transport()->GetRemoteSSLCertChain();
552   ASSERT_TRUE(remote_cert2);
553   ASSERT_EQ(1u, remote_cert2->GetSize());
554   ASSERT_EQ(remote_cert2->Get(0).ToPEMString(),
555             certificate1->GetSSLCertificate().ToPEMString());
556 }
557 
558 // Test that packets are retransmitted according to the expected schedule.
559 // Each time a timeout occurs, the retransmission timer should be doubled up to
560 // 60 seconds. The timer defaults to 1 second, but for WebRTC we should be
561 // initializing it to 50ms.
TEST_F(DtlsTransportTest,TestRetransmissionSchedule)562 TEST_F(DtlsTransportTest, TestRetransmissionSchedule) {
563   // We can only change the retransmission schedule with a recently-added
564   // BoringSSL API. Skip the test if not built with BoringSSL.
565   MAYBE_SKIP_TEST(IsBoringSsl);
566 
567   PrepareDtls(rtc::KT_DEFAULT);
568   // Exchange fingerprints and set SSL roles.
569   Negotiate();
570 
571   // Make client2_ writable, but not client1_.
572   // This means client1_ will send DTLS client hellos but get no response.
573   EXPECT_TRUE(client2_.Connect(&client1_, true));
574   EXPECT_TRUE_SIMULATED_WAIT(client2_.fake_ice_transport()->writable(),
575                              kTimeout, fake_clock_);
576 
577   // Wait for the first client hello to be sent.
578   EXPECT_EQ_WAIT(1, client1_.received_dtls_client_hellos(), kTimeout);
579   EXPECT_FALSE(client1_.fake_ice_transport()->writable());
580 
581   static int timeout_schedule_ms[] = {50,   100,  200,   400,   800,   1600,
582                                       3200, 6400, 12800, 25600, 51200, 60000};
583 
584   int expected_hellos = 1;
585   for (size_t i = 0;
586        i < (sizeof(timeout_schedule_ms) / sizeof(timeout_schedule_ms[0]));
587        ++i) {
588     // For each expected retransmission time, advance the fake clock a
589     // millisecond before the expected time and verify that no unexpected
590     // retransmissions were sent. Then advance it the final millisecond and
591     // verify that the expected retransmission was sent.
592     fake_clock_.AdvanceTime(
593         webrtc::TimeDelta::Millis(timeout_schedule_ms[i] - 1));
594     EXPECT_EQ(expected_hellos, client1_.received_dtls_client_hellos());
595     fake_clock_.AdvanceTime(webrtc::TimeDelta::Millis(1));
596     EXPECT_EQ(++expected_hellos, client1_.received_dtls_client_hellos());
597   }
598 }
599 
// The following events can happen in many different orders:
// 1. Caller receives remote fingerprint.
// 2. Caller is writable.
// 3. Caller receives ClientHello.
// 4. DTLS handshake finishes.
//
// The tests below cover every causally consistent permutation of these
// events: the caller must be writable and must have received a ClientHello
// before the handshake finishes, but any other ordering is possible.
//
// For each permutation, the test checks that a connection is established and
// the fingerprint verified without any DTLS packet being retransmitted.
//
// Each permutation is also run with both a valid and an invalid fingerprint,
// and with an invalid one the handshake must fail.
enum DtlsTransportEvent {
  CALLER_RECEIVES_FINGERPRINT = 0,  // Caller learns the callee's fingerprint.
  CALLER_WRITABLE = 1,              // Caller's ICE transport becomes writable.
  CALLER_RECEIVES_CLIENTHELLO = 2,  // Callee's ClientHello reaches the caller.
  HANDSHAKE_FINISHES = 3,           // Caller's DTLS connects or fails.
};
621 
622 class DtlsEventOrderingTest
623     : public DtlsTransportTestBase,
624       public ::testing::TestWithParam<
625           ::testing::tuple<std::vector<DtlsTransportEvent>, bool>> {
626  protected:
627   // If `valid_fingerprint` is false, the caller will receive a fingerprint
628   // that doesn't match the callee's certificate, so the handshake should fail.
TestEventOrdering(const std::vector<DtlsTransportEvent> & events,bool valid_fingerprint)629   void TestEventOrdering(const std::vector<DtlsTransportEvent>& events,
630                          bool valid_fingerprint) {
631     // Pre-setup: Set local certificate on both caller and callee, and
632     // remote fingerprint on callee, but neither is writable and the caller
633     // doesn't have the callee's fingerprint.
634     PrepareDtls(rtc::KT_DEFAULT);
635     // Simulate packets being sent and arriving asynchronously.
636     // Otherwise the entire DTLS handshake would occur in one clock tick, and
637     // we couldn't inject method calls in the middle of it.
638     int simulated_delay_ms = 10;
639     client1_.SetupTransports(ICEROLE_CONTROLLING, simulated_delay_ms);
640     client2_.SetupTransports(ICEROLE_CONTROLLED, simulated_delay_ms);
641     // Similar to how NegotiateOrdering works.
642     client1_.dtls_transport()->SetDtlsRole(rtc::SSL_SERVER);
643     client2_.dtls_transport()->SetDtlsRole(rtc::SSL_CLIENT);
644     SetRemoteFingerprintFromCert(client2_.dtls_transport(),
645                                  client1_.certificate());
646 
647     for (DtlsTransportEvent e : events) {
648       switch (e) {
649         case CALLER_RECEIVES_FINGERPRINT:
650           if (valid_fingerprint) {
651             SetRemoteFingerprintFromCert(client1_.dtls_transport(),
652                                          client2_.certificate());
653           } else {
654             SetRemoteFingerprintFromCert(client1_.dtls_transport(),
655                                          client2_.certificate(),
656                                          true /*modify_digest*/);
657           }
658           break;
659         case CALLER_WRITABLE:
660           EXPECT_TRUE(client1_.Connect(&client2_, true));
661           EXPECT_TRUE_SIMULATED_WAIT(client1_.fake_ice_transport()->writable(),
662                                      kTimeout, fake_clock_);
663           break;
664         case CALLER_RECEIVES_CLIENTHELLO:
665           // Sanity check that a ClientHello hasn't already been received.
666           EXPECT_EQ(0, client1_.received_dtls_client_hellos());
667           // Making client2_ writable will cause it to send the ClientHello.
668           EXPECT_TRUE(client2_.Connect(&client1_, true));
669           EXPECT_TRUE_SIMULATED_WAIT(client2_.fake_ice_transport()->writable(),
670                                      kTimeout, fake_clock_);
671           EXPECT_EQ_SIMULATED_WAIT(1, client1_.received_dtls_client_hellos(),
672                                    kTimeout, fake_clock_);
673           break;
674         case HANDSHAKE_FINISHES:
675           // Sanity check that the handshake hasn't already finished.
676           EXPECT_FALSE(client1_.dtls_transport()->IsDtlsConnected() ||
677                        client1_.dtls_transport()->dtls_state() ==
678                            webrtc::DtlsTransportState::kFailed);
679           EXPECT_TRUE_SIMULATED_WAIT(
680               client1_.dtls_transport()->IsDtlsConnected() ||
681                   client1_.dtls_transport()->dtls_state() ==
682                       webrtc::DtlsTransportState::kFailed,
683               kTimeout, fake_clock_);
684           break;
685       }
686     }
687 
688     webrtc::DtlsTransportState expected_final_state =
689         valid_fingerprint ? webrtc::DtlsTransportState::kConnected
690                           : webrtc::DtlsTransportState::kFailed;
691     EXPECT_EQ_SIMULATED_WAIT(expected_final_state,
692                              client1_.dtls_transport()->dtls_state(), kTimeout,
693                              fake_clock_);
694     EXPECT_EQ_SIMULATED_WAIT(expected_final_state,
695                              client2_.dtls_transport()->dtls_state(), kTimeout,
696                              fake_clock_);
697 
698     // Transports should be writable iff there was a valid fingerprint.
699     EXPECT_EQ(valid_fingerprint, client1_.dtls_transport()->writable());
700     EXPECT_EQ(valid_fingerprint, client2_.dtls_transport()->writable());
701 
702     // Check that no hello needed to be retransmitted.
703     EXPECT_EQ(1, client1_.received_dtls_client_hellos());
704     EXPECT_EQ(1, client2_.received_dtls_server_hellos());
705 
706     if (valid_fingerprint) {
707       TestTransfer(1000, 100, false);
708     }
709   }
710 };
711 
// Runs one event ordering, with either a valid or a corrupted fingerprint.
TEST_P(DtlsEventOrderingTest, TestEventOrdering) {
  const std::vector<DtlsTransportEvent>& events =
      ::testing::get<0>(GetParam());
  const bool valid_fingerprint = ::testing::get<1>(GetParam());
  TestEventOrdering(events, valid_fingerprint);
}
716 
717 INSTANTIATE_TEST_SUITE_P(
718     TestEventOrdering,
719     DtlsEventOrderingTest,
720     ::testing::Combine(
721         ::testing::Values(
722             std::vector<DtlsTransportEvent>{
723                 CALLER_RECEIVES_FINGERPRINT, CALLER_WRITABLE,
724                 CALLER_RECEIVES_CLIENTHELLO, HANDSHAKE_FINISHES},
725             std::vector<DtlsTransportEvent>{
726                 CALLER_WRITABLE, CALLER_RECEIVES_FINGERPRINT,
727                 CALLER_RECEIVES_CLIENTHELLO, HANDSHAKE_FINISHES},
728             std::vector<DtlsTransportEvent>{
729                 CALLER_WRITABLE, CALLER_RECEIVES_CLIENTHELLO,
730                 CALLER_RECEIVES_FINGERPRINT, HANDSHAKE_FINISHES},
731             std::vector<DtlsTransportEvent>{
732                 CALLER_WRITABLE, CALLER_RECEIVES_CLIENTHELLO,
733                 HANDSHAKE_FINISHES, CALLER_RECEIVES_FINGERPRINT},
734             std::vector<DtlsTransportEvent>{
735                 CALLER_RECEIVES_FINGERPRINT, CALLER_RECEIVES_CLIENTHELLO,
736                 CALLER_WRITABLE, HANDSHAKE_FINISHES},
737             std::vector<DtlsTransportEvent>{
738                 CALLER_RECEIVES_CLIENTHELLO, CALLER_RECEIVES_FINGERPRINT,
739                 CALLER_WRITABLE, HANDSHAKE_FINISHES},
740             std::vector<DtlsTransportEvent>{
741                 CALLER_RECEIVES_CLIENTHELLO, CALLER_WRITABLE,
742                 CALLER_RECEIVES_FINGERPRINT, HANDSHAKE_FINISHES},
743             std::vector<DtlsTransportEvent>{CALLER_RECEIVES_CLIENTHELLO,
744                                             CALLER_WRITABLE, HANDSHAKE_FINISHES,
745                                             CALLER_RECEIVES_FINGERPRINT}),
746         ::testing::Bool()));
747 
748 }  // namespace cricket
749