xref: /aosp_15_r20/external/webrtc/pc/dtls_transport_unittest.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2018 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "pc/dtls_transport.h"
12 
13 #include <utility>
14 #include <vector>
15 
16 #include "absl/types/optional.h"
17 #include "api/make_ref_counted.h"
18 #include "api/rtc_error.h"
19 #include "p2p/base/fake_dtls_transport.h"
20 #include "p2p/base/p2p_constants.h"
21 #include "rtc_base/fake_ssl_identity.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/rtc_certificate.h"
24 #include "rtc_base/ssl_identity.h"
25 #include "test/gmock.h"
26 #include "test/gtest.h"
27 
// Upper bound, in milliseconds, for the *_WAIT assertions below.
constexpr int kDefaultTimeout{1000};
// Arbitrary cipher-suite id injected into the fake transport so the tests can
// check it is surfaced unchanged through DtlsTransportInformation.
constexpr int kNonsenseCipherSuite{1234};
30 
31 using cricket::FakeDtlsTransport;
32 using ::testing::ElementsAre;
33 
34 namespace webrtc {
35 
36 class TestDtlsTransportObserver : public DtlsTransportObserverInterface {
37  public:
OnStateChange(DtlsTransportInformation info)38   void OnStateChange(DtlsTransportInformation info) override {
39     state_change_called_ = true;
40     states_.push_back(info.state());
41     info_ = info;
42   }
43 
OnError(RTCError error)44   void OnError(RTCError error) override {}
45 
state()46   DtlsTransportState state() {
47     if (states_.size() > 0) {
48       return states_[states_.size() - 1];
49     } else {
50       return DtlsTransportState::kNew;
51     }
52   }
53 
54   bool state_change_called_ = false;
55   DtlsTransportInformation info_;
56   std::vector<DtlsTransportState> states_;
57 };
58 
59 class DtlsTransportTest : public ::testing::Test {
60  public:
transport()61   DtlsTransport* transport() { return transport_.get(); }
observer()62   DtlsTransportObserverInterface* observer() { return &observer_; }
63 
CreateTransport(rtc::FakeSSLCertificate * certificate=nullptr)64   void CreateTransport(rtc::FakeSSLCertificate* certificate = nullptr) {
65     auto cricket_transport = std::make_unique<FakeDtlsTransport>(
66         "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
67     if (certificate) {
68       cricket_transport->SetRemoteSSLCertificate(certificate);
69     }
70     cricket_transport->SetSslCipherSuite(kNonsenseCipherSuite);
71     transport_ =
72         rtc::make_ref_counted<DtlsTransport>(std::move(cricket_transport));
73   }
74 
CompleteDtlsHandshake()75   void CompleteDtlsHandshake() {
76     auto fake_dtls1 = static_cast<FakeDtlsTransport*>(transport_->internal());
77     auto fake_dtls2 = std::make_unique<FakeDtlsTransport>(
78         "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
79     auto cert1 = rtc::RTCCertificate::Create(
80         rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
81     fake_dtls1->SetLocalCertificate(cert1);
82     auto cert2 = rtc::RTCCertificate::Create(
83         rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
84     fake_dtls2->SetLocalCertificate(cert2);
85     fake_dtls1->SetDestination(fake_dtls2.get());
86   }
87 
88   rtc::AutoThread main_thread_;
89   rtc::scoped_refptr<DtlsTransport> transport_;
90   TestDtlsTransportObserver observer_;
91 };
92 
// Lifecycle without the fixture helpers: a newly wrapped transport exposes
// its internal transport and reports kNew; after Clear() the internal
// transport is gone and the state is kClosed.
TEST_F(DtlsTransportTest, CreateClearDelete) {
  auto wrapped = rtc::make_ref_counted<DtlsTransport>(
      std::make_unique<FakeDtlsTransport>(
          "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP));
  ASSERT_NE(nullptr, wrapped->internal());
  ASSERT_EQ(DtlsTransportState::kNew, wrapped->Information().state());

  wrapped->Clear();

  ASSERT_EQ(nullptr, wrapped->internal());
  ASSERT_EQ(DtlsTransportState::kClosed, wrapped->Information().state());
}
105 
// Completing the handshake must be reported to a registered observer.
TEST_F(DtlsTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  transport()->RegisterObserver(observer());

  CompleteDtlsHandshake();

  ASSERT_TRUE_WAIT(observer_.state_change_called_, kDefaultTimeout);
  // FakeDtlsTransport doesn't signal the "connecting" state, so only
  // kConnected is expected here.
  // TODO(hta): fix FakeDtlsTransport or file bug on it, then also expect
  // DtlsTransportState::kConnecting ahead of kConnected.
  EXPECT_THAT(observer_.states_,
              ElementsAre(DtlsTransportState::kConnected));
}
118 
// Clearing a connected transport must drive the observer to kClosed.
TEST_F(DtlsTransportTest, CloseWhenClearing) {
  // Bring the transport up with an observer attached.
  CreateTransport();
  transport()->RegisterObserver(observer());
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);

  // Tear it down and wait for the close notification.
  transport()->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
}
129 
// No TLS role is known before the handshake. Once connected, both the
// snapshot delivered to the observer and the transport's current information
// must report the same negotiated role.
TEST_F(DtlsTransportTest, RoleAppearsOnConnect) {
  rtc::FakeSSLCertificate fake_certificate("fake data");
  CreateTransport(&fake_certificate);
  transport()->RegisterObserver(observer());
  EXPECT_FALSE(transport()->Information().role());
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);
  // Pin the exact value on both paths (not just presence for the observer)
  // so a mismatch between them is caught.
  EXPECT_EQ(observer_.info_.role(), DtlsTransportTlsRole::kClient);
  EXPECT_EQ(transport()->Information().role(), DtlsTransportTlsRole::kClient);
}
142 
// A transport created with a remote certificate exposes a remote certificate
// chain to the observer once connected.
TEST_F(DtlsTransportTest, CertificateAppearsOnConnect) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  transport()->RegisterObserver(observer());

  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);

  EXPECT_NE(nullptr, observer_.info_.remote_ssl_certificates());
}
152 
// The remote certificate chain reported while connected is dropped from the
// observer's information once the transport is cleared.
TEST_F(DtlsTransportTest, CertificateDisappearsOnClose) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  transport()->RegisterObserver(observer());

  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);
  EXPECT_NE(nullptr, observer_.info_.remote_ssl_certificates());

  transport()->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
  EXPECT_EQ(nullptr, observer_.info_.remote_ssl_certificates());
}
166 
// The cipher suite injected into the fake transport is visible while
// connected and disappears after the transport is cleared.
TEST_F(DtlsTransportTest, CipherSuiteVisibleWhenConnected) {
  CreateTransport();
  transport()->RegisterObserver(observer());

  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(DtlsTransportState::kConnected == observer_.state(),
                   kDefaultTimeout);
  ASSERT_TRUE(observer_.info_.ssl_cipher_suite().has_value());
  EXPECT_EQ(kNonsenseCipherSuite, *observer_.info_.ssl_cipher_suite());

  transport()->Clear();
  ASSERT_TRUE_WAIT(DtlsTransportState::kClosed == observer_.state(),
                   kDefaultTimeout);
  EXPECT_FALSE(observer_.info_.ssl_cipher_suite().has_value());
}
180 
181 }  // namespace webrtc
182