1 /*
2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include <cstdint>
11 #include <deque>
12 #include <memory>
13 #include <string>
14 #include <utility>
15 #include <vector>
16
17 #include "absl/memory/memory.h"
18 #include "absl/strings/string_view.h"
19 #include "absl/types/optional.h"
20 #include "api/array_view.h"
21 #include "api/task_queue/pending_task_safety_flag.h"
22 #include "api/task_queue/task_queue_base.h"
23 #include "api/test/create_network_emulation_manager.h"
24 #include "api/test/network_emulation_manager.h"
25 #include "api/units/time_delta.h"
26 #include "call/simulated_network.h"
27 #include "net/dcsctp/public/dcsctp_options.h"
28 #include "net/dcsctp/public/dcsctp_socket.h"
29 #include "net/dcsctp/public/types.h"
30 #include "net/dcsctp/socket/dcsctp_socket.h"
31 #include "net/dcsctp/testing/testing_macros.h"
32 #include "net/dcsctp/timer/task_queue_timeout.h"
33 #include "rtc_base/copy_on_write_buffer.h"
34 #include "rtc_base/gunit.h"
35 #include "rtc_base/logging.h"
36 #include "rtc_base/socket_address.h"
37 #include "rtc_base/strings/string_format.h"
38 #include "rtc_base/time_utils.h"
39 #include "test/gmock.h"
40
// `DCSCTP_NDEBUG_TEST(name)` expands to `name` only in release (NDEBUG) builds
// that are neither Android, TSAN nor MSAN builds. Everywhere else it expands
// to `DISABLED_name`: in debug mode, and when the MSAN or TSAN sanitizers are
// enabled, these tests are too expensive to run due to extensive consistency
// checks that iterate on all outstanding chunks. The same goes for low-end
// Android devices, which have difficulties with these tests.
#if defined(WEBRTC_ANDROID) || !defined(NDEBUG) || \
    defined(THREAD_SANITIZER) || defined(MEMORY_SANITIZER)
#define DCSCTP_NDEBUG_TEST(t) DISABLED_##t
#else
#define DCSCTP_NDEBUG_TEST(t) t
#endif
51
52 namespace dcsctp {
53 namespace {
54 using ::testing::AllOf;
55 using ::testing::Ge;
56 using ::testing::Le;
57 using ::testing::SizeIs;
58
59 constexpr StreamID kStreamId(1);
60 constexpr PPID kPpid(53);
61 constexpr size_t kSmallPayloadSize = 10;
62 constexpr size_t kLargePayloadSize = 10000;
63 constexpr size_t kHugePayloadSize = 262144;
64 constexpr size_t kBufferedAmountLowThreshold = kLargePayloadSize * 2;
65 constexpr webrtc::TimeDelta kPrintBandwidthDuration =
66 webrtc::TimeDelta::Seconds(1);
67 constexpr webrtc::TimeDelta kBenchmarkRuntime(webrtc::TimeDelta::Seconds(10));
68 constexpr webrtc::TimeDelta kAWhile(webrtc::TimeDelta::Seconds(1));
69
// Hands out a new seed on every call: 1 on the first call, then 2, 3, ...
// The counter is a function-local static and is not synchronized.
inline int GetUniqueSeed() {
  static int last_seed = 0;
  last_seed += 1;
  return last_seed;
}
74
MakeOptionsForTest()75 DcSctpOptions MakeOptionsForTest() {
76 DcSctpOptions options;
77
78 // Throughput numbers are affected by the MTU. Ensure it's constant.
79 options.mtu = 1200;
80
81 // By disabling the heartbeat interval, there will no timers at all running
82 // when the socket is idle, which makes it easy to just continue the test
83 // until there are no more scheduled tasks. Note that it _will_ run for longer
84 // than necessary as timers aren't cancelled when they are stopped (as that's
85 // not supported), but it's still simulated time and passes quickly.
86 options.heartbeat_interval = DurationMs(0);
87 return options;
88 }
89
// Selects what an `SctpActor` should do while a throughput test is running.
// The mode is applied with `SctpActor::SetActorMode`.
enum class ActorMode {
  // No sending and no bandwidth printouts are started for this mode.
  kAtRest,
  // Sends a `kHugePayloadSize` message when the mode is entered and another
  // one every time the socket reports that the buffered amount is low.
  kThroughputSender,
  // Samples and logs the received bitrate every `kPrintBandwidthDuration`.
  kThroughputReceiver,
  // When the buffered amount is low, queues pairs of one `kLargePayloadSize`
  // message with `max_retransmissions = 0` and one `kSmallPayloadSize` message
  // without a retransmission limit.
  kLimitedRetransmissionSender,
};
97
98 // An abstraction around EmulatedEndpoint, representing a bound socket that
99 // will send its packet to a given destination.
100 class BoundSocket : public webrtc::EmulatedNetworkReceiverInterface {
101 public:
Bind(webrtc::EmulatedEndpoint * endpoint)102 void Bind(webrtc::EmulatedEndpoint* endpoint) {
103 endpoint_ = endpoint;
104 uint16_t port = endpoint->BindReceiver(0, this).value();
105 source_address_ =
106 rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port);
107 }
108
SetDestination(const BoundSocket & socket)109 void SetDestination(const BoundSocket& socket) {
110 dest_address_ = socket.source_address_;
111 }
112
SetReceiver(std::function<void (rtc::CopyOnWriteBuffer)> receiver)113 void SetReceiver(std::function<void(rtc::CopyOnWriteBuffer)> receiver) {
114 receiver_ = std::move(receiver);
115 }
116
SendPacket(rtc::ArrayView<const uint8_t> data)117 void SendPacket(rtc::ArrayView<const uint8_t> data) {
118 endpoint_->SendPacket(source_address_, dest_address_,
119 rtc::CopyOnWriteBuffer(data.data(), data.size()));
120 }
121
122 private:
123 // Implementation of `webrtc::EmulatedNetworkReceiverInterface`.
OnPacketReceived(webrtc::EmulatedIpPacket packet)124 void OnPacketReceived(webrtc::EmulatedIpPacket packet) override {
125 receiver_(std::move(packet.data));
126 }
127
128 std::function<void(rtc::CopyOnWriteBuffer)> receiver_;
129 webrtc::EmulatedEndpoint* endpoint_ = nullptr;
130 rtc::SocketAddress source_address_;
131 rtc::SocketAddress dest_address_;
132 };
133
134 // Sends at a constant rate but with random packet sizes.
135 class SctpActor : public DcSctpSocketCallbacks {
136 public:
SctpActor(absl::string_view name,BoundSocket & emulated_socket,const DcSctpOptions & sctp_options)137 SctpActor(absl::string_view name,
138 BoundSocket& emulated_socket,
139 const DcSctpOptions& sctp_options)
140 : log_prefix_(std::string(name) + ": "),
141 thread_(rtc::Thread::Current()),
142 emulated_socket_(emulated_socket),
143 timeout_factory_(
144 *thread_,
145 [this]() { return TimeMillis(); },
__anon65f5f2e00302(dcsctp::TimeoutID timeout_id) 146 [this](dcsctp::TimeoutID timeout_id) {
147 sctp_socket_.HandleTimeout(timeout_id);
148 }),
149 random_(GetUniqueSeed()),
150 sctp_socket_(name, *this, nullptr, sctp_options),
151 last_bandwidth_printout_(TimeMs(TimeMillis())) {
__anon65f5f2e00402(rtc::CopyOnWriteBuffer buf) 152 emulated_socket.SetReceiver([this](rtc::CopyOnWriteBuffer buf) {
153 // The receiver will be executed on the NetworkEmulation task queue, but
154 // the dcSCTP socket is owned by `thread_` and is not thread-safe.
155 thread_->PostTask([this, buf] { this->sctp_socket_.ReceivePacket(buf); });
156 });
157 }
158
PrintBandwidth()159 void PrintBandwidth() {
160 TimeMs now = TimeMillis();
161 DurationMs duration = now - last_bandwidth_printout_;
162
163 double bitrate_mbps =
164 static_cast<double>(received_bytes_ * 8) / *duration / 1000;
165 RTC_LOG(LS_INFO) << log_prefix()
166 << rtc::StringFormat("Received %0.2f Mbps", bitrate_mbps);
167
168 received_bitrate_mbps_.push_back(bitrate_mbps);
169 received_bytes_ = 0;
170 last_bandwidth_printout_ = now;
171 // Print again in a second.
172 if (mode_ == ActorMode::kThroughputReceiver) {
173 thread_->PostDelayedTask(
174 SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
175 kPrintBandwidthDuration);
176 }
177 }
178
SendPacket(rtc::ArrayView<const uint8_t> data)179 void SendPacket(rtc::ArrayView<const uint8_t> data) override {
180 emulated_socket_.SendPacket(data);
181 }
182
CreateTimeout(webrtc::TaskQueueBase::DelayPrecision precision)183 std::unique_ptr<Timeout> CreateTimeout(
184 webrtc::TaskQueueBase::DelayPrecision precision) override {
185 return timeout_factory_.CreateTimeout(precision);
186 }
187
TimeMillis()188 TimeMs TimeMillis() override { return TimeMs(rtc::TimeMillis()); }
189
GetRandomInt(uint32_t low,uint32_t high)190 uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
191 return random_.Rand(low, high);
192 }
193
OnMessageReceived(DcSctpMessage message)194 void OnMessageReceived(DcSctpMessage message) override {
195 received_bytes_ += message.payload().size();
196 last_received_message_ = std::move(message);
197 }
198
OnError(ErrorKind error,absl::string_view message)199 void OnError(ErrorKind error, absl::string_view message) override {
200 RTC_LOG(LS_WARNING) << log_prefix() << "Socket error: " << ToString(error)
201 << "; " << message;
202 }
203
OnAborted(ErrorKind error,absl::string_view message)204 void OnAborted(ErrorKind error, absl::string_view message) override {
205 RTC_LOG(LS_ERROR) << log_prefix() << "Socket abort: " << ToString(error)
206 << "; " << message;
207 }
208
OnConnected()209 void OnConnected() override {}
210
OnClosed()211 void OnClosed() override {}
212
OnConnectionRestarted()213 void OnConnectionRestarted() override {}
214
OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,absl::string_view reason)215 void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
216 absl::string_view reason) override {}
217
OnStreamsResetPerformed(rtc::ArrayView<const StreamID> outgoing_streams)218 void OnStreamsResetPerformed(
219 rtc::ArrayView<const StreamID> outgoing_streams) override {}
220
OnIncomingStreamsReset(rtc::ArrayView<const StreamID> incoming_streams)221 void OnIncomingStreamsReset(
222 rtc::ArrayView<const StreamID> incoming_streams) override {}
223
NotifyOutgoingMessageBufferEmpty()224 void NotifyOutgoingMessageBufferEmpty() override {}
225
OnBufferedAmountLow(StreamID stream_id)226 void OnBufferedAmountLow(StreamID stream_id) override {
227 if (mode_ == ActorMode::kThroughputSender) {
228 std::vector<uint8_t> payload(kHugePayloadSize);
229 sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
230 SendOptions());
231
232 } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
233 while (sctp_socket_.buffered_amount(kStreamId) <
234 kBufferedAmountLowThreshold * 2) {
235 SendOptions send_options;
236 send_options.max_retransmissions = 0;
237 sctp_socket_.Send(
238 DcSctpMessage(kStreamId, kPpid,
239 std::vector<uint8_t>(kLargePayloadSize)),
240 send_options);
241
242 send_options.max_retransmissions = absl::nullopt;
243 sctp_socket_.Send(
244 DcSctpMessage(kStreamId, kPpid,
245 std::vector<uint8_t>(kSmallPayloadSize)),
246 send_options);
247 }
248 }
249 }
250
ConsumeReceivedMessage()251 absl::optional<DcSctpMessage> ConsumeReceivedMessage() {
252 if (!last_received_message_.has_value()) {
253 return absl::nullopt;
254 }
255 DcSctpMessage ret = *std::move(last_received_message_);
256 last_received_message_ = absl::nullopt;
257 return ret;
258 }
259
sctp_socket()260 DcSctpSocket& sctp_socket() { return sctp_socket_; }
261
SetActorMode(ActorMode mode)262 void SetActorMode(ActorMode mode) {
263 mode_ = mode;
264 if (mode_ == ActorMode::kThroughputSender) {
265 sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
266 kBufferedAmountLowThreshold);
267 std::vector<uint8_t> payload(kHugePayloadSize);
268 sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
269 SendOptions());
270
271 } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
272 sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
273 kBufferedAmountLowThreshold);
274 std::vector<uint8_t> payload(kHugePayloadSize);
275 sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
276 SendOptions());
277
278 } else if (mode == ActorMode::kThroughputReceiver) {
279 thread_->PostDelayedTask(
280 SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
281 kPrintBandwidthDuration);
282 }
283 }
284
285 // Returns the average bitrate, stripping the first `remove_first_n` that
286 // represent the time it took to ramp up the congestion control algorithm.
avg_received_bitrate_mbps(size_t remove_first_n=3) const287 double avg_received_bitrate_mbps(size_t remove_first_n = 3) const {
288 std::vector<double> bitrates = received_bitrate_mbps_;
289 bitrates.erase(bitrates.begin(), bitrates.begin() + remove_first_n);
290
291 double sum = 0;
292 for (double bitrate : bitrates) {
293 sum += bitrate;
294 }
295
296 return sum / bitrates.size();
297 }
298
299 private:
log_prefix() const300 std::string log_prefix() const {
301 rtc::StringBuilder sb;
302 sb << log_prefix_;
303 sb << rtc::TimeMillis();
304 sb << ": ";
305 return sb.Release();
306 }
307
308 ActorMode mode_ = ActorMode::kAtRest;
309 const std::string log_prefix_;
310 rtc::Thread* thread_;
311 BoundSocket& emulated_socket_;
312 TaskQueueTimeoutFactory timeout_factory_;
313 webrtc::Random random_;
314 DcSctpSocket sctp_socket_;
315 size_t received_bytes_ = 0;
316 absl::optional<DcSctpMessage> last_received_message_;
317 TimeMs last_bandwidth_printout_;
318 // Per-second received bitrates, in Mbps
319 std::vector<double> received_bitrate_mbps_;
320 webrtc::ScopedTaskSafety safety_;
321 };
322
323 class DcSctpSocketNetworkTest : public testing::Test {
324 protected:
DcSctpSocketNetworkTest()325 DcSctpSocketNetworkTest()
326 : options_(MakeOptionsForTest()),
327 emulation_(webrtc::CreateNetworkEmulationManager(
328 webrtc::TimeMode::kSimulated)) {}
329
MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig & config)330 void MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig& config) {
331 webrtc::EmulatedEndpoint* endpoint_a =
332 emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());
333 webrtc::EmulatedEndpoint* endpoint_z =
334 emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());
335
336 webrtc::EmulatedNetworkNode* node1 = emulation_->CreateEmulatedNode(config);
337 webrtc::EmulatedNetworkNode* node2 = emulation_->CreateEmulatedNode(config);
338
339 emulation_->CreateRoute(endpoint_a, {node1}, endpoint_z);
340 emulation_->CreateRoute(endpoint_z, {node2}, endpoint_a);
341
342 emulated_socket_a_.Bind(endpoint_a);
343 emulated_socket_z_.Bind(endpoint_z);
344
345 emulated_socket_a_.SetDestination(emulated_socket_z_);
346 emulated_socket_z_.SetDestination(emulated_socket_a_);
347 }
348
Sleep(webrtc::TimeDelta duration)349 void Sleep(webrtc::TimeDelta duration) {
350 // Sleep in one-millisecond increments, to let timers expire when expected.
351 for (int i = 0; i < duration.ms(); ++i) {
352 emulation_->time_controller()->AdvanceTime(webrtc::TimeDelta::Millis(1));
353 }
354 }
355
356 DcSctpOptions options_;
357 std::unique_ptr<webrtc::NetworkEmulationManager> emulation_;
358 BoundSocket emulated_socket_a_;
359 BoundSocket emulated_socket_z_;
360 };
361
// Over a network with the default behavior config, the connecting socket is
// expected to go from closed to connected after Connect(), and back to closed
// after Shutdown().
TEST_F(DcSctpSocketNetworkTest, CanConnectAndShutdown) {
  MakeNetwork(webrtc::BuiltInNetworkBehaviorConfig());

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  DcSctpSocket& socket_a = actor_a.sctp_socket();
  EXPECT_EQ(socket_a.state(), SocketState::kClosed);

  socket_a.Connect();
  Sleep(kAWhile);
  EXPECT_EQ(socket_a.state(), SocketState::kConnected);

  socket_a.Shutdown();
  Sleep(kAWhile);
  EXPECT_EQ(socket_a.state(), SocketState::kClosed);
}
378
// Sends a single 100 KiB message over a link with 30 ms queue delay and
// checks that it arrives in one piece on the other side.
TEST_F(DcSctpSocketNetworkTest, CanSendLargeMessage) {
  constexpr size_t kPayloadSize = 100 * 1024;

  webrtc::BuiltInNetworkBehaviorConfig link;
  link.queue_delay_ms = 30;
  MakeNetwork(link);

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  actor_a.sctp_socket().Connect();

  // The message is queued right after Connect() has been called.
  actor_a.sctp_socket().Send(
      DcSctpMessage(kStreamId, kPpid, std::vector<uint8_t>(kPayloadSize)),
      SendOptions());

  Sleep(kAWhile);

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage received,
                              actor_z.ConsumeReceivedMessage());
  EXPECT_THAT(received.payload(), SizeIs(kPayloadSize));

  actor_a.sctp_socket().Shutdown();
  Sleep(kAWhile);
}
404
// Saturates a 1000 kbps link with 30 ms queue delay for `kBenchmarkRuntime`
// and checks the average bitrate measured by the receiver.
TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithLowBandwidth) {
  webrtc::BuiltInNetworkBehaviorConfig link;
  link.link_capacity_kbps = 1000;
  link.queue_delay_ms = 30;
  MakeNetwork(link);

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  actor_a.sctp_socket().Connect();

  // Run the benchmark: A sends as fast as it can, Z measures.
  actor_a.SetActorMode(ActorMode::kThroughputSender);
  actor_z.SetActorMode(ActorMode::kThroughputReceiver);
  Sleep(kBenchmarkRuntime);

  // Stop both sides and give them a moment before shutting down.
  actor_a.SetActorMode(ActorMode::kAtRest);
  actor_z.SetActorMode(ActorMode::kAtRest);
  Sleep(kAWhile);

  actor_a.sctp_socket().Shutdown();
  Sleep(kAWhile);

  // The average bitrate is expected to be within 0.5-1.0 Mbps.
  EXPECT_THAT(actor_z.avg_received_bitrate_mbps(), AllOf(Ge(0.5), Le(1.0)));
}
432
// Saturates an 18000 kbps link with 30 ms queue delay for `kBenchmarkRuntime`
// and checks the average bitrate measured by the receiver. The test name goes
// through `DCSCTP_NDEBUG_TEST`, so it is disabled in some build
// configurations.
TEST_F(DcSctpSocketNetworkTest,
       DCSCTP_NDEBUG_TEST(CanSendMessagesReliablyWithMediumBandwidth)) {
  webrtc::BuiltInNetworkBehaviorConfig link;
  link.link_capacity_kbps = 18000;
  link.queue_delay_ms = 30;
  MakeNetwork(link);

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  actor_a.sctp_socket().Connect();

  // Run the benchmark: A sends as fast as it can, Z measures.
  actor_a.SetActorMode(ActorMode::kThroughputSender);
  actor_z.SetActorMode(ActorMode::kThroughputReceiver);
  Sleep(kBenchmarkRuntime);

  // Stop both sides and give them a moment before shutting down.
  actor_a.SetActorMode(ActorMode::kAtRest);
  actor_z.SetActorMode(ActorMode::kAtRest);
  Sleep(kAWhile);

  actor_a.sctp_socket().Shutdown();
  Sleep(kAWhile);

  // The average bitrate is expected to be within 16-18 Mbps.
  EXPECT_THAT(actor_z.avg_received_bitrate_mbps(), AllOf(Ge(16), Le(18)));
}
461
// Sends at full speed for `kBenchmarkRuntime` over a link with 30 ms queue
// delay and 1% packet loss, and checks the average bitrate measured by the
// receiver.
TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithMuchPacketLoss) {
  webrtc::BuiltInNetworkBehaviorConfig lossy_link;
  lossy_link.loss_percent = 1;
  lossy_link.queue_delay_ms = 30;
  MakeNetwork(lossy_link);

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  actor_a.sctp_socket().Connect();

  // Run the benchmark: A sends as fast as it can, Z measures.
  actor_a.SetActorMode(ActorMode::kThroughputSender);
  actor_z.SetActorMode(ActorMode::kThroughputReceiver);
  Sleep(kBenchmarkRuntime);

  // Stop both sides and give them a moment before shutting down.
  actor_a.SetActorMode(ActorMode::kAtRest);
  actor_z.SetActorMode(ActorMode::kAtRest);
  Sleep(kAWhile);

  actor_a.sctp_socket().Shutdown();
  Sleep(kAWhile);

  // TCP calculator gives: 1200 MTU, 60ms RTT and 1% packet loss -> 1.6Mbps.
  // This test is doing slightly better (doesn't have any additional header
  // overhead etc). The average bitrate is expected to be within 1.5-2.5 Mbps.
  EXPECT_THAT(actor_z.avg_received_bitrate_mbps(), AllOf(Ge(1.5), Le(2.5)));
}
491
// Sends at full speed for `kBenchmarkRuntime` over a link with 30 ms queue
// delay and no configured capacity limit, and checks the average bitrate
// measured by the receiver. The test name goes through `DCSCTP_NDEBUG_TEST`,
// so it is disabled in some build configurations.
TEST_F(DcSctpSocketNetworkTest, DCSCTP_NDEBUG_TEST(HasHighBandwidth)) {
  webrtc::BuiltInNetworkBehaviorConfig link;
  link.queue_delay_ms = 30;
  MakeNetwork(link);

  SctpActor actor_a("A", emulated_socket_a_, options_);
  SctpActor actor_z("Z", emulated_socket_z_, options_);
  actor_a.sctp_socket().Connect();

  // Run the benchmark: A sends as fast as it can, Z measures.
  actor_a.SetActorMode(ActorMode::kThroughputSender);
  actor_z.SetActorMode(ActorMode::kThroughputReceiver);
  Sleep(kBenchmarkRuntime);

  // Stop both sides and give them a moment before shutting down.
  actor_a.SetActorMode(ActorMode::kAtRest);
  actor_z.SetActorMode(ActorMode::kAtRest);
  Sleep(kAWhile);

  actor_a.sctp_socket().Shutdown();
  Sleep(kAWhile);

  // The average bitrate is expected to be within 520-640 Mbps.
  EXPECT_THAT(actor_z.avg_received_bitrate_mbps(), AllOf(Ge(520), Le(640)));
}
517 } // namespace
518 } // namespace dcsctp
519