1 /*
2 * Copyright 2019 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
#include "pc/sctp_transport.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "api/dtls_transport_interface.h"
#include "api/transport/data_channel_transport_interface.h"
#include "media/base/media_channel.h"
#include "p2p/base/fake_dtls_transport.h"
#include "p2p/base/p2p_constants.h"
#include "p2p/base/packet_transport_internal.h"
#include "pc/dtls_transport.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/gunit.h"
#include "test/gmock.h"
#include "test/gtest.h"
29
// Longest time, in milliseconds, that a test waits for an expected transport
// state to be observed.
constexpr int kDefaultTimeout{1000};
// Outbound stream limit configured on the fake SCTP transport; the tests
// expect this value to be reported as MaxChannels.
constexpr int kTestMaxSctpStreams{1234};
32
33 using cricket::FakeDtlsTransport;
34 using ::testing::ElementsAre;
35
36 namespace webrtc {
37
38 namespace {
39
40 class FakeCricketSctpTransport : public cricket::SctpTransportInternal {
41 public:
SetOnConnectedCallback(std::function<void ()> callback)42 void SetOnConnectedCallback(std::function<void()> callback) override {
43 on_connected_callback_ = std::move(callback);
44 }
SetDataChannelSink(DataChannelSink * sink)45 void SetDataChannelSink(DataChannelSink* sink) override {}
SetDtlsTransport(rtc::PacketTransportInternal * transport)46 void SetDtlsTransport(rtc::PacketTransportInternal* transport) override {}
Start(int local_port,int remote_port,int max_message_size)47 bool Start(int local_port, int remote_port, int max_message_size) override {
48 return true;
49 }
OpenStream(int sid)50 bool OpenStream(int sid) override { return true; }
ResetStream(int sid)51 bool ResetStream(int sid) override { return true; }
SendData(int sid,const SendDataParams & params,const rtc::CopyOnWriteBuffer & payload,cricket::SendDataResult * result=nullptr)52 bool SendData(int sid,
53 const SendDataParams& params,
54 const rtc::CopyOnWriteBuffer& payload,
55 cricket::SendDataResult* result = nullptr) override {
56 return true;
57 }
ReadyToSendData()58 bool ReadyToSendData() override { return true; }
set_debug_name_for_testing(const char * debug_name)59 void set_debug_name_for_testing(const char* debug_name) override {}
max_message_size() const60 int max_message_size() const override { return 0; }
max_outbound_streams() const61 absl::optional<int> max_outbound_streams() const override {
62 return max_outbound_streams_;
63 }
max_inbound_streams() const64 absl::optional<int> max_inbound_streams() const override {
65 return max_inbound_streams_;
66 }
67
SendSignalAssociationChangeCommunicationUp()68 void SendSignalAssociationChangeCommunicationUp() {
69 ASSERT_TRUE(on_connected_callback_);
70 on_connected_callback_();
71 }
72
set_max_outbound_streams(int streams)73 void set_max_outbound_streams(int streams) {
74 max_outbound_streams_ = streams;
75 }
set_max_inbound_streams(int streams)76 void set_max_inbound_streams(int streams) { max_inbound_streams_ = streams; }
77
78 private:
79 absl::optional<int> max_outbound_streams_;
80 absl::optional<int> max_inbound_streams_;
81 std::function<void()> on_connected_callback_;
82 };
83
84 } // namespace
85
86 class TestSctpTransportObserver : public SctpTransportObserverInterface {
87 public:
TestSctpTransportObserver()88 TestSctpTransportObserver() : info_(SctpTransportState::kNew) {}
89
OnStateChange(SctpTransportInformation info)90 void OnStateChange(SctpTransportInformation info) override {
91 info_ = info;
92 states_.push_back(info.state());
93 }
94
State()95 SctpTransportState State() {
96 if (states_.size() > 0) {
97 return states_[states_.size() - 1];
98 } else {
99 return SctpTransportState::kNew;
100 }
101 }
102
States()103 const std::vector<SctpTransportState>& States() { return states_; }
104
LastReceivedInformation()105 const SctpTransportInformation LastReceivedInformation() { return info_; }
106
107 private:
108 std::vector<SctpTransportState> states_;
109 SctpTransportInformation info_;
110 };
111
112 class SctpTransportTest : public ::testing::Test {
113 public:
transport()114 SctpTransport* transport() { return transport_.get(); }
observer()115 SctpTransportObserverInterface* observer() { return &observer_; }
116
CreateTransport()117 void CreateTransport() {
118 auto cricket_sctp_transport =
119 absl::WrapUnique(new FakeCricketSctpTransport());
120 transport_ =
121 rtc::make_ref_counted<SctpTransport>(std::move(cricket_sctp_transport));
122 }
123
AddDtlsTransport()124 void AddDtlsTransport() {
125 std::unique_ptr<cricket::DtlsTransportInternal> cricket_transport =
126 std::make_unique<FakeDtlsTransport>(
127 "audio", cricket::ICE_CANDIDATE_COMPONENT_RTP);
128 dtls_transport_ =
129 rtc::make_ref_counted<DtlsTransport>(std::move(cricket_transport));
130 transport_->SetDtlsTransport(dtls_transport_);
131 }
132
CompleteSctpHandshake()133 void CompleteSctpHandshake() {
134 // The computed MaxChannels shall be the minimum of the outgoing
135 // and incoming # of streams.
136 CricketSctpTransport()->set_max_outbound_streams(kTestMaxSctpStreams);
137 CricketSctpTransport()->set_max_inbound_streams(kTestMaxSctpStreams + 1);
138 CricketSctpTransport()->SendSignalAssociationChangeCommunicationUp();
139 }
140
CricketSctpTransport()141 FakeCricketSctpTransport* CricketSctpTransport() {
142 return static_cast<FakeCricketSctpTransport*>(transport_->internal());
143 }
144
145 rtc::AutoThread main_thread_;
146 rtc::scoped_refptr<SctpTransport> transport_;
147 rtc::scoped_refptr<DtlsTransport> dtls_transport_;
148 TestSctpTransportObserver observer_;
149 };
150
// A freshly created transport exposes its internal transport and is in the
// kNew state; Clear() drops the internal transport and moves it to kClosed.
TEST(SctpTransportSimpleTest, CreateClearDelete) {
  rtc::AutoThread main_thread;
  std::unique_ptr<cricket::SctpTransportInternal> fake_cricket_sctp_transport =
      std::make_unique<FakeCricketSctpTransport>();
  rtc::scoped_refptr<SctpTransport> sctp_transport =
      rtc::make_ref_counted<SctpTransport>(
          std::move(fake_cricket_sctp_transport));
  ASSERT_TRUE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kNew, sctp_transport->Information().state());
  sctp_transport->Clear();
  ASSERT_FALSE(sctp_transport->internal());
  ASSERT_EQ(SctpTransportState::kClosed, sctp_transport->Information().state());
}
164
// Completing the handshake after a DTLS transport has been attached must be
// reported to the observer as kConnecting followed by kConnected, and
// nothing else.
TEST_F(SctpTransportTest, EventsObservedWhenConnecting) {
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();

  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  const std::vector<SctpTransportState>& reported_states = observer_.States();
  EXPECT_THAT(reported_states,
              ::testing::ElementsAre(SctpTransportState::kConnecting,
                                     SctpTransportState::kConnected));
}
175
// Clearing a connected transport must be reported to the observer as a
// transition to kClosed.
TEST_F(SctpTransportTest, CloseWhenClearing) {
  // Bring the transport to the connected state first.
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  transport_->Clear();

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
187
// MaxChannels is unset before the handshake completes. Afterwards both the
// transport and the information delivered to the observer must report
// kTestMaxSctpStreams (the smaller of the fake's two stream limits).
TEST_F(SctpTransportTest, MaxChannelsSignalled) {
  CreateTransport();
  transport()->RegisterObserver(observer());
  AddDtlsTransport();
  EXPECT_FALSE(transport()->Information().MaxChannels());
  EXPECT_FALSE(observer_.LastReceivedInformation().MaxChannels());
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // The presence checks are fatal (ASSERT_*) so that an unset value ends the
  // test here instead of being dereferenced on the following line.
  ASSERT_TRUE(transport()->Information().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams, *(transport()->Information().MaxChannels()));
  ASSERT_TRUE(observer_.LastReceivedInformation().MaxChannels());
  EXPECT_EQ(kTestMaxSctpStreams,
            *(observer_.LastReceivedInformation().MaxChannels()));
}
203
// Once the underlying DTLS transport reports kClosed, the observer must see
// the SCTP transport move to kClosed as well.
TEST_F(SctpTransportTest, CloseWhenTransportCloses) {
  // Bring the transport to the connected state first.
  CreateTransport();
  transport_->RegisterObserver(&observer_);
  AddDtlsTransport();
  CompleteSctpHandshake();
  ASSERT_EQ_WAIT(SctpTransportState::kConnected, observer_.State(),
                 kDefaultTimeout);

  // The fixture always installs a FakeDtlsTransport, so the downcast is safe.
  auto* fake_dtls =
      static_cast<cricket::FakeDtlsTransport*>(dtls_transport_->internal());
  fake_dtls->SetDtlsState(DtlsTransportState::kClosed);

  ASSERT_EQ_WAIT(SctpTransportState::kClosed, observer_.State(),
                 kDefaultTimeout);
}
216
217 } // namespace webrtc
218