1 /*
2 * Copyright 2018 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "pc/dtls_transport.h"
12
#include <memory>
#include <utility>
#include <vector>
15
16 #include "absl/types/optional.h"
17 #include "api/make_ref_counted.h"
18 #include "api/rtc_error.h"
19 #include "p2p/base/fake_dtls_transport.h"
20 #include "p2p/base/p2p_constants.h"
21 #include "rtc_base/fake_ssl_identity.h"
22 #include "rtc_base/gunit.h"
23 #include "rtc_base/rtc_certificate.h"
24 #include "rtc_base/ssl_identity.h"
25 #include "test/gmock.h"
26 #include "test/gtest.h"
27
// Upper bound, in milliseconds, passed to the *_WAIT assertions below when
// polling for an expected transport state.
constexpr int kDefaultTimeout{1000};
// Arbitrary cipher suite id injected into the fake transport; tests check that
// exactly this value is reported back through DtlsTransportInformation.
constexpr int kNonsenseCipherSuite{1234};
30
31 using cricket::FakeDtlsTransport;
32 using ::testing::ElementsAre;
33
34 namespace webrtc {
35
36 class TestDtlsTransportObserver : public DtlsTransportObserverInterface {
37 public:
OnStateChange(DtlsTransportInformation info)38 void OnStateChange(DtlsTransportInformation info) override {
39 state_change_called_ = true;
40 states_.push_back(info.state());
41 info_ = info;
42 }
43
OnError(RTCError error)44 void OnError(RTCError error) override {}
45
state()46 DtlsTransportState state() {
47 if (states_.size() > 0) {
48 return states_[states_.size() - 1];
49 } else {
50 return DtlsTransportState::kNew;
51 }
52 }
53
54 bool state_change_called_ = false;
55 DtlsTransportInformation info_;
56 std::vector<DtlsTransportState> states_;
57 };
58
59 class DtlsTransportTest : public ::testing::Test {
60 public:
transport()61 DtlsTransport* transport() { return transport_.get(); }
observer()62 DtlsTransportObserverInterface* observer() { return &observer_; }
63
CreateTransport(rtc::FakeSSLCertificate * certificate=nullptr)64 void CreateTransport(rtc::FakeSSLCertificate* certificate = nullptr) {
65 auto cricket_transport = std::make_unique<FakeDtlsTransport>(
66 "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
67 if (certificate) {
68 cricket_transport->SetRemoteSSLCertificate(certificate);
69 }
70 cricket_transport->SetSslCipherSuite(kNonsenseCipherSuite);
71 transport_ =
72 rtc::make_ref_counted<DtlsTransport>(std::move(cricket_transport));
73 }
74
CompleteDtlsHandshake()75 void CompleteDtlsHandshake() {
76 auto fake_dtls1 = static_cast<FakeDtlsTransport*>(transport_->internal());
77 auto fake_dtls2 = std::make_unique<FakeDtlsTransport>(
78 "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
79 auto cert1 = rtc::RTCCertificate::Create(
80 rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
81 fake_dtls1->SetLocalCertificate(cert1);
82 auto cert2 = rtc::RTCCertificate::Create(
83 rtc::SSLIdentity::Create("session1", rtc::KT_DEFAULT));
84 fake_dtls2->SetLocalCertificate(cert2);
85 fake_dtls1->SetDestination(fake_dtls2.get());
86 }
87
88 rtc::AutoThread main_thread_;
89 rtc::scoped_refptr<DtlsTransport> transport_;
90 TestDtlsTransportObserver observer_;
91 };
92
// A freshly built DtlsTransport exposes its internal transport and reports
// kNew; after Clear() the internal transport is gone and the state is kClosed.
TEST_F(DtlsTransportTest, CreateClearDelete) {
  auto fake_internal = std::make_unique<FakeDtlsTransport>(
      "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
  rtc::scoped_refptr<DtlsTransport> dtls_transport =
      rtc::make_ref_counted<DtlsTransport>(std::move(fake_internal));

  // Before clearing.
  ASSERT_TRUE(dtls_transport->internal());
  ASSERT_EQ(DtlsTransportState::kNew, dtls_transport->Information().state());

  dtls_transport->Clear();

  // After clearing.
  ASSERT_FALSE(dtls_transport->internal());
  ASSERT_EQ(DtlsTransportState::kClosed,
            dtls_transport->Information().state());
}
105
// An observer registered before the handshake receives the connection's
// state transitions in order.
TEST_F(DtlsTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state_change_called_, kDefaultTimeout);

  // FakeDtlsTransport doesn't signal the "connecting" state, so
  // DtlsTransportState::kConnecting is not part of the expected sequence.
  // TODO(hta): fix FakeDtlsTransport or file bug on it.
  EXPECT_THAT(observer_.states_, ElementsAre(DtlsTransportState::kConnected));
}
118
// Clearing a connected transport moves the observed state to kClosed.
TEST_F(DtlsTransportTest, CloseWhenClearing) {
  CreateTransport();
  DtlsTransport* dtls = transport();
  dtls->RegisterObserver(observer());

  // Bring the transport up first.
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);

  // Then tear it down and wait for the close to be reported.
  dtls->Clear();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kClosed,
                   kDefaultTimeout);
}
129
// The TLS role is absent before the handshake and reported as kClient, both
// to the observer and through Information(), once connected.
TEST_F(DtlsTransportTest, RoleAppearsOnConnect) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  DtlsTransport* dtls = transport();
  dtls->RegisterObserver(observer());

  // No role yet.
  EXPECT_FALSE(dtls->Information().role());

  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);

  EXPECT_TRUE(observer_.info_.role());
  const auto connected_role = dtls->Information().role();
  EXPECT_TRUE(connected_role);
  EXPECT_EQ(connected_role, DtlsTransportTlsRole::kClient);
}
142
// Once connected, the observer's snapshot carries the remote certificates.
TEST_F(DtlsTransportTest, CertificateAppearsOnConnect) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  DtlsTransport* dtls = transport();
  dtls->RegisterObserver(observer());

  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);

  EXPECT_TRUE(observer_.info_.remote_ssl_certificates() != nullptr);
}
152
// Remote certificates reported while connected are gone from the observer's
// snapshot after the transport is cleared.
TEST_F(DtlsTransportTest, CertificateDisappearsOnClose) {
  rtc::FakeSSLCertificate remote_cert("fake data");
  CreateTransport(&remote_cert);
  DtlsTransport* dtls = transport();
  dtls->RegisterObserver(observer());

  // Connected: certificates are visible.
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);
  EXPECT_TRUE(observer_.info_.remote_ssl_certificates() != nullptr);

  // Closed: certificates are no longer reported.
  dtls->Clear();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kClosed,
                   kDefaultTimeout);
  EXPECT_FALSE(observer_.info_.remote_ssl_certificates());
}
166
// The cipher suite injected by CreateTransport() is reported while connected
// and disappears once the transport is cleared.
TEST_F(DtlsTransportTest, CipherSuiteVisibleWhenConnected) {
  CreateTransport();
  DtlsTransport* dtls = transport();
  dtls->RegisterObserver(observer());

  // Connected: the injected suite is visible.
  CompleteDtlsHandshake();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kConnected,
                   kDefaultTimeout);
  ASSERT_TRUE(observer_.info_.ssl_cipher_suite());
  EXPECT_EQ(kNonsenseCipherSuite, *observer_.info_.ssl_cipher_suite());

  // Closed: no suite is reported any more.
  dtls->Clear();
  ASSERT_TRUE_WAIT(observer_.state() == DtlsTransportState::kClosed,
                   kDefaultTimeout);
  EXPECT_FALSE(observer_.info_.ssl_cipher_suite());
}
180
181 } // namespace webrtc
182