1 /*
2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "net/dcsctp/socket/dcsctp_socket.h"
11
12 #include <cstdint>
13 #include <deque>
14 #include <memory>
15 #include <string>
16 #include <utility>
17 #include <vector>
18
19 #include "absl/flags/flag.h"
20 #include "absl/memory/memory.h"
21 #include "absl/strings/string_view.h"
22 #include "absl/types/optional.h"
23 #include "api/array_view.h"
24 #include "net/dcsctp/common/handover_testing.h"
25 #include "net/dcsctp/packet/chunk/chunk.h"
26 #include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
27 #include "net/dcsctp/packet/chunk/data_chunk.h"
28 #include "net/dcsctp/packet/chunk/data_common.h"
29 #include "net/dcsctp/packet/chunk/error_chunk.h"
30 #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
31 #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
32 #include "net/dcsctp/packet/chunk/idata_chunk.h"
33 #include "net/dcsctp/packet/chunk/init_chunk.h"
34 #include "net/dcsctp/packet/chunk/sack_chunk.h"
35 #include "net/dcsctp/packet/chunk/shutdown_chunk.h"
36 #include "net/dcsctp/packet/error_cause/error_cause.h"
37 #include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h"
38 #include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
39 #include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
40 #include "net/dcsctp/packet/parameter/parameter.h"
41 #include "net/dcsctp/packet/sctp_packet.h"
42 #include "net/dcsctp/packet/tlv_trait.h"
43 #include "net/dcsctp/public/dcsctp_message.h"
44 #include "net/dcsctp/public/dcsctp_options.h"
45 #include "net/dcsctp/public/dcsctp_socket.h"
46 #include "net/dcsctp/public/text_pcap_packet_observer.h"
47 #include "net/dcsctp/public/types.h"
48 #include "net/dcsctp/rx/reassembly_queue.h"
49 #include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
50 #include "net/dcsctp/testing/testing_macros.h"
51 #include "rtc_base/gunit.h"
52 #include "test/gmock.h"
53
54 ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture.");
55
56 namespace dcsctp {
57 namespace {
58 using ::testing::_;
59 using ::testing::AllOf;
60 using ::testing::ElementsAre;
61 using ::testing::HasSubstr;
62 using ::testing::IsEmpty;
63 using ::testing::SizeIs;
64 using ::testing::UnorderedElementsAre;
65
66 constexpr SendOptions kSendOptions;
67 constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20;
68 constexpr size_t kSmallMessageSize = 10;
69 constexpr int kMaxBurstPackets = 4;
70
71 MATCHER_P(HasDataChunkWithStreamId, stream_id, "") {
72 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
73 if (!packet.has_value()) {
74 *result_listener << "data didn't parse as an SctpPacket";
75 return false;
76 }
77
78 if (packet->descriptors()[0].type != DataChunk::kType) {
79 *result_listener << "the first chunk in the packet is not a data chunk";
80 return false;
81 }
82
83 absl::optional<DataChunk> dc =
84 DataChunk::Parse(packet->descriptors()[0].data);
85 if (!dc.has_value()) {
86 *result_listener << "The first chunk didn't parse as a data chunk";
87 return false;
88 }
89
90 if (dc->stream_id() != stream_id) {
91 *result_listener << "the stream_id is " << *dc->stream_id();
92 return false;
93 }
94
95 return true;
96 }
97
98 MATCHER_P(HasDataChunkWithPPID, ppid, "") {
99 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
100 if (!packet.has_value()) {
101 *result_listener << "data didn't parse as an SctpPacket";
102 return false;
103 }
104
105 if (packet->descriptors()[0].type != DataChunk::kType) {
106 *result_listener << "the first chunk in the packet is not a data chunk";
107 return false;
108 }
109
110 absl::optional<DataChunk> dc =
111 DataChunk::Parse(packet->descriptors()[0].data);
112 if (!dc.has_value()) {
113 *result_listener << "The first chunk didn't parse as a data chunk";
114 return false;
115 }
116
117 if (dc->ppid() != ppid) {
118 *result_listener << "the ppid is " << *dc->ppid();
119 return false;
120 }
121
122 return true;
123 }
124
125 MATCHER_P(HasDataChunkWithSsn, ssn, "") {
126 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
127 if (!packet.has_value()) {
128 *result_listener << "data didn't parse as an SctpPacket";
129 return false;
130 }
131
132 if (packet->descriptors()[0].type != DataChunk::kType) {
133 *result_listener << "the first chunk in the packet is not a data chunk";
134 return false;
135 }
136
137 absl::optional<DataChunk> dc =
138 DataChunk::Parse(packet->descriptors()[0].data);
139 if (!dc.has_value()) {
140 *result_listener << "The first chunk didn't parse as a data chunk";
141 return false;
142 }
143
144 if (dc->ssn() != ssn) {
145 *result_listener << "the ssn is " << *dc->ssn();
146 return false;
147 }
148
149 return true;
150 }
151
152 MATCHER_P(HasDataChunkWithMid, mid, "") {
153 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
154 if (!packet.has_value()) {
155 *result_listener << "data didn't parse as an SctpPacket";
156 return false;
157 }
158
159 if (packet->descriptors()[0].type != IDataChunk::kType) {
160 *result_listener << "the first chunk in the packet is not an i-data chunk";
161 return false;
162 }
163
164 absl::optional<IDataChunk> dc =
165 IDataChunk::Parse(packet->descriptors()[0].data);
166 if (!dc.has_value()) {
167 *result_listener << "The first chunk didn't parse as an i-data chunk";
168 return false;
169 }
170
171 if (dc->message_id() != mid) {
172 *result_listener << "the mid is " << *dc->message_id();
173 return false;
174 }
175
176 return true;
177 }
178
179 MATCHER_P(HasSackWithCumAckTsn, tsn, "") {
180 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
181 if (!packet.has_value()) {
182 *result_listener << "data didn't parse as an SctpPacket";
183 return false;
184 }
185
186 if (packet->descriptors()[0].type != SackChunk::kType) {
187 *result_listener << "the first chunk in the packet is not a data chunk";
188 return false;
189 }
190
191 absl::optional<SackChunk> sc =
192 SackChunk::Parse(packet->descriptors()[0].data);
193 if (!sc.has_value()) {
194 *result_listener << "The first chunk didn't parse as a data chunk";
195 return false;
196 }
197
198 if (sc->cumulative_tsn_ack() != tsn) {
199 *result_listener << "the cum_ack_tsn is " << *sc->cumulative_tsn_ack();
200 return false;
201 }
202
203 return true;
204 }
205
206 MATCHER(HasSackWithNoGapAckBlocks, "") {
207 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
208 if (!packet.has_value()) {
209 *result_listener << "data didn't parse as an SctpPacket";
210 return false;
211 }
212
213 if (packet->descriptors()[0].type != SackChunk::kType) {
214 *result_listener << "the first chunk in the packet is not a data chunk";
215 return false;
216 }
217
218 absl::optional<SackChunk> sc =
219 SackChunk::Parse(packet->descriptors()[0].data);
220 if (!sc.has_value()) {
221 *result_listener << "The first chunk didn't parse as a data chunk";
222 return false;
223 }
224
225 if (!sc->gap_ack_blocks().empty()) {
226 *result_listener << "there are gap ack blocks";
227 return false;
228 }
229
230 return true;
231 }
232
233 MATCHER_P(HasReconfigWithStreams, streams_matcher, "") {
234 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
235 if (!packet.has_value()) {
236 *result_listener << "data didn't parse as an SctpPacket";
237 return false;
238 }
239
240 if (packet->descriptors()[0].type != ReConfigChunk::kType) {
241 *result_listener << "the first chunk in the packet is not a data chunk";
242 return false;
243 }
244
245 absl::optional<ReConfigChunk> reconfig =
246 ReConfigChunk::Parse(packet->descriptors()[0].data);
247 if (!reconfig.has_value()) {
248 *result_listener << "The first chunk didn't parse as a data chunk";
249 return false;
250 }
251
252 const Parameters& parameters = reconfig->parameters();
253 if (parameters.descriptors().size() != 1 ||
254 parameters.descriptors()[0].type !=
255 OutgoingSSNResetRequestParameter::kType) {
256 *result_listener << "Expected the reconfig chunk to have an outgoing SSN "
257 "reset request parameter";
258 return false;
259 }
260
261 absl::optional<OutgoingSSNResetRequestParameter> p =
262 OutgoingSSNResetRequestParameter::Parse(parameters.descriptors()[0].data);
263 testing::Matcher<rtc::ArrayView<const StreamID>> matcher = streams_matcher;
264 if (!matcher.MatchAndExplain(p->stream_ids(), result_listener)) {
265 return false;
266 }
267
268 return true;
269 }
270
271 MATCHER_P(HasReconfigWithResponse, result, "") {
272 absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
273 if (!packet.has_value()) {
274 *result_listener << "data didn't parse as an SctpPacket";
275 return false;
276 }
277
278 if (packet->descriptors()[0].type != ReConfigChunk::kType) {
279 *result_listener << "the first chunk in the packet is not a reconfig chunk";
280 return false;
281 }
282
283 absl::optional<ReConfigChunk> reconfig =
284 ReConfigChunk::Parse(packet->descriptors()[0].data);
285 if (!reconfig.has_value()) {
286 *result_listener << "The first chunk didn't parse as a reconfig chunk";
287 return false;
288 }
289
290 const Parameters& parameters = reconfig->parameters();
291 if (parameters.descriptors().size() != 1 ||
292 parameters.descriptors()[0].type !=
293 ReconfigurationResponseParameter::kType) {
294 *result_listener << "Expected the reconfig chunk to have a "
295 "ReconfigurationResponse Parameter";
296 return false;
297 }
298
299 absl::optional<ReconfigurationResponseParameter> p =
300 ReconfigurationResponseParameter::Parse(parameters.descriptors()[0].data);
301 if (p->result() != result) {
302 *result_listener << "ReconfigurationResponse Parameter doesn't contain the "
303 "expected result";
304 return false;
305 }
306
307 return true;
308 }
309
AddTo(TSN tsn,int delta)310 TSN AddTo(TSN tsn, int delta) {
311 return TSN(*tsn + delta);
312 }
313
FixupOptions(DcSctpOptions options={})314 DcSctpOptions FixupOptions(DcSctpOptions options = {}) {
315 DcSctpOptions fixup = options;
316 // To make the interval more predictable in tests.
317 fixup.heartbeat_interval_include_rtt = false;
318 fixup.max_burst = kMaxBurstPackets;
319 return fixup;
320 }
321
GetPacketObserver(absl::string_view name)322 std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) {
323 if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) {
324 return std::make_unique<TextPcapPacketObserver>(name);
325 }
326 return nullptr;
327 }
328
329 struct SocketUnderTest {
SocketUnderTestdcsctp::__anon2dc231170111::SocketUnderTest330 explicit SocketUnderTest(absl::string_view name,
331 const DcSctpOptions& opts = {})
332 : options(FixupOptions(opts)),
333 cb(name),
334 socket(name, cb, GetPacketObserver(name), options) {}
335
336 const DcSctpOptions options;
337 testing::NiceMock<MockDcSctpSocketCallbacks> cb;
338 DcSctpSocket socket;
339 };
340
ExchangeMessages(SocketUnderTest & a,SocketUnderTest & z)341 void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) {
342 bool delivered_packet = false;
343 do {
344 delivered_packet = false;
345 std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
346 if (!packet_from_a.empty()) {
347 delivered_packet = true;
348 z.socket.ReceivePacket(std::move(packet_from_a));
349 }
350 std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket();
351 if (!packet_from_z.empty()) {
352 delivered_packet = true;
353 a.socket.ReceivePacket(std::move(packet_from_z));
354 }
355 } while (delivered_packet);
356 }
357
RunTimers(SocketUnderTest & s)358 void RunTimers(SocketUnderTest& s) {
359 for (;;) {
360 absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout();
361 if (!timeout_id.has_value()) {
362 break;
363 }
364 s.socket.HandleTimeout(*timeout_id);
365 }
366 }
367
AdvanceTime(SocketUnderTest & a,SocketUnderTest & z,DurationMs duration)368 void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) {
369 a.cb.AdvanceTime(duration);
370 z.cb.AdvanceTime(duration);
371
372 RunTimers(a);
373 RunTimers(z);
374 }
375
376 // Calls Connect() on `sock_a_` and make the connection established.
ConnectSockets(SocketUnderTest & a,SocketUnderTest & z)377 void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) {
378 EXPECT_CALL(a.cb, OnConnected).Times(1);
379 EXPECT_CALL(z.cb, OnConnected).Times(1);
380
381 a.socket.Connect();
382 // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK
383 z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
384 a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
385 z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
386 a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
387
388 EXPECT_EQ(a.socket.state(), SocketState::kConnected);
389 EXPECT_EQ(z.socket.state(), SocketState::kConnected);
390 }
391
HandoverSocket(std::unique_ptr<SocketUnderTest> sut)392 std::unique_ptr<SocketUnderTest> HandoverSocket(
393 std::unique_ptr<SocketUnderTest> sut) {
394 EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus());
395
396 bool is_closed = sut->socket.state() == SocketState::kClosed;
397 if (!is_closed) {
398 EXPECT_CALL(sut->cb, OnClosed).Times(1);
399 }
400 absl::optional<DcSctpSocketHandoverState> handover_state =
401 sut->socket.GetHandoverStateAndClose();
402 EXPECT_TRUE(handover_state.has_value());
403 g_handover_state_transformer_for_test(&*handover_state);
404
405 auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options);
406 if (!is_closed) {
407 EXPECT_CALL(handover_socket->cb, OnConnected).Times(1);
408 }
409 handover_socket->socket.RestoreFromState(*handover_state);
410 return handover_socket;
411 }
412
GetReceivedMessagePpids(SocketUnderTest & z)413 std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) {
414 std::vector<uint32_t> ppids;
415 for (;;) {
416 absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
417 if (!msg.has_value()) {
418 break;
419 }
420 ppids.push_back(*msg->ppid());
421 }
422 return ppids;
423 }
424
// Test parameter that decides whether socket Z is handed over during a test.
// A test may contain several points at which it conditionally hands over Z.
// Depending on this value, Z is handed over at every such point or at none.
enum class HandoverMode {
  kNoHandover = 0,
  kPerformHandovers = 1,
};
432
433 class DcSctpSocketParametrizedTest
434 : public ::testing::Test,
435 public ::testing::WithParamInterface<HandoverMode> {
436 protected:
437 // Trigger handover for `sut` depending on the current test param.
MaybeHandoverSocket(std::unique_ptr<SocketUnderTest> sut)438 std::unique_ptr<SocketUnderTest> MaybeHandoverSocket(
439 std::unique_ptr<SocketUnderTest> sut) {
440 if (GetParam() == HandoverMode::kPerformHandovers) {
441 return HandoverSocket(std::move(sut));
442 }
443 return sut;
444 }
445
446 // Trigger handover for socket Z depending on the current test param.
447 // Then checks message passing to verify the handed over socket is functional.
MaybeHandoverSocketAndSendMessage(SocketUnderTest & a,std::unique_ptr<SocketUnderTest> z)448 void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a,
449 std::unique_ptr<SocketUnderTest> z) {
450 if (GetParam() == HandoverMode::kPerformHandovers) {
451 z = HandoverSocket(std::move(z));
452 }
453
454 ExchangeMessages(a, *z);
455 a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
456 ExchangeMessages(a, *z);
457
458 absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
459 ASSERT_TRUE(msg.has_value());
460 EXPECT_EQ(msg->stream_id(), StreamID(1));
461 }
462 };
463
464 INSTANTIATE_TEST_SUITE_P(Handovers,
465 DcSctpSocketParametrizedTest,
466 testing::Values(HandoverMode::kNoHandover,
467 HandoverMode::kPerformHandovers),
__anon2dc231170202(const auto& test_info) 468 [](const auto& test_info) {
469 return test_info.param ==
470 HandoverMode::kPerformHandovers
471 ? "WithHandovers"
472 : "NoHandover";
473 });
474
TEST(DcSctpSocketTest, EstablishConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Each side connects exactly once and never reports a restart.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);

  a.socket.Connect();
  // INIT: A -> Z. Z answers with INIT_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // INIT_ACK: Z -> A. A answers with COOKIE_ECHO.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // COOKIE_ECHO: A -> Z. Z answers with COOKIE_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // COOKIE_ACK: Z -> A.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
497
TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Although both sides initiate, each must connect exactly once and must not
  // report a restart.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);

  // Both sockets start connecting before any packet has been exchanged.
  a.socket.Connect();
  z.socket.Connect();

  ExchangeMessages(a, z);

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
514
// Shutting down a socket that is still establishing its connection (waiting
// for COOKIE_ACK) must stop its handshake timers and complete a regular
// SHUTDOWN exchange with the peer.
TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  EXPECT_CALL(a.cb, OnConnected).Times(0);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  a.socket.Connect();

  // Z reads INIT, produces INIT_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads INIT_ACK, produces COOKIE_ECHO
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // Z reads COOKIE_ECHO, produces COOKIE_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // Drop COOKIE_ACK, just to more easily verify shutdown protocol.
  z.cb.ConsumeSentPacket();

  // Socket Z has received COOKIE_ECHO and is therefore connected. Socket A has
  // received INIT_ACK (so it has a TCB), but it still needs the dropped
  // COOKIE_ACK to become connected, and it still has timers running at this
  // point.
  EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);

  // Socket A is now shut down, which should make it stop those timers.
  a.socket.Shutdown();

  EXPECT_CALL(a.cb, OnClosed).Times(1);
  EXPECT_CALL(z.cb, OnClosed).Times(1);

  // Z reads SHUTDOWN, produces SHUTDOWN_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // Z reads SHUTDOWN_COMPLETE.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());

  EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
  EXPECT_TRUE(z.cb.ConsumeSentPacket().empty());

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);
  EXPECT_EQ(z.socket.state(), SocketState::kClosed);
}
557
TEST(DcSctpSocketTest, EstablishSimultaneousConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Each side connects exactly once and never reports a restart.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);

  a.socket.Connect();
  // Z isn't ready yet, so A's INIT is lost.
  a.cb.ConsumeSentPacket();

  z.socket.Connect();
  // INIT: Z -> A. A answers with INIT_ACK.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // INIT_ACK: A -> Z. Z answers with COOKIE_ECHO.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // COOKIE_ECHO: Z -> A, which establishes the connection on A's side.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);

  // Deliver whatever is still in flight.
  ExchangeMessages(a, z);

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
590
TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Each side connects exactly once and never reports a restart.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);

  a.socket.Connect();
  // INIT: A -> Z. Z answers with INIT_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // INIT_ACK: Z -> A. A answers with COOKIE_ECHO.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // COOKIE_ECHO: A -> Z. Z answers with COOKIE_ACK ...
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // ... which never reaches A.
  z.cb.ConsumeSentPacket();

  EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);

  // Once T1-COOKIE expires, A sends COOKIE_ECHO again.
  AdvanceTime(a, z, a.options.t1_cookie_timeout);

  // COOKIE_ECHO: A -> Z. Z answers with COOKIE_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // COOKIE_ACK: Z -> A.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
624
TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Connect();
  // The first INIT is lost on its way to Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket lost_init,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(lost_init.descriptors()[0].type, InitChunk::kType);

  // T1-INIT expires and A resends INIT.
  AdvanceTime(a, z, a.options.t1_init_timeout);

  // INIT: A -> Z. Z answers with INIT_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // INIT_ACK: Z -> A. A answers with COOKIE_ECHO.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // COOKIE_ECHO: A -> Z. Z answers with COOKIE_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // COOKIE_ACK: Z -> A.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
649
TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Connect();

  // No INIT ever reaches Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket first_init,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(first_init.descriptors()[0].type, InitChunk::kType);

  const auto max_retransmits = *a.options.max_init_retransmits;
  for (int attempt = 0; attempt < max_retransmits; ++attempt) {
    // The timeout doubles on every expiry.
    AdvanceTime(a, z, a.options.t1_init_timeout * (1 << attempt));

    // INIT is sent again (and lost again).
    ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket retransmitted,
                                SctpPacket::Parse(a.cb.ConsumeSentPacket()));
    EXPECT_EQ(retransmitted.descriptors()[0].type, InitChunk::kType);
  }

  // One more expiry, after all allowed retransmissions, aborts the attempt.
  EXPECT_CALL(a.cb, OnAborted).Times(1);
  AdvanceTime(a, z, a.options.t1_init_timeout * (1 << max_retransmits));

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);
}
677
// A lost COOKIE_ECHO must be retransmitted when T1-COOKIE expires, after which
// the handshake completes normally.
TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Connect();

  // Z reads INIT, produces INIT_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads INIT_ACK, produces COOKIE_ECHO
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  // COOKIE_ECHO is never received by Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(cookie_echo_packet.descriptors()[0].type, CookieEchoChunk::kType);

  // Once COOKIE_ECHO has been sent, T1-COOKIE (not T1-INIT) drives its
  // retransmission. Advance by that timer so the test doesn't depend on both
  // timeouts happening to be equal.
  AdvanceTime(a, z, a.options.t1_cookie_timeout);

  // Z reads COOKIE_ECHO, produces COOKIE_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads COOKIE_ACK.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);
}
704
TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Connect();

  // INIT: A -> Z. Z answers with INIT_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // INIT_ACK: Z -> A. A answers with COOKIE_ECHO.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  // No COOKIE_ECHO ever reaches Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket first_cookie_echo,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(first_cookie_echo.descriptors()[0].type, CookieEchoChunk::kType);

  // COOKIE_ECHO retransmissions are bounded by the same max_init_retransmits
  // limit as INIT.
  const auto max_retransmits = *a.options.max_init_retransmits;
  for (int attempt = 0; attempt < max_retransmits; ++attempt) {
    // The timeout doubles on every expiry.
    AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << attempt));

    // COOKIE_ECHO is sent again (and lost again).
    ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket retransmitted,
                                SctpPacket::Parse(a.cb.ConsumeSentPacket()));
    EXPECT_EQ(retransmitted.descriptors()[0].type, CookieEchoChunk::kType);
  }

  // One more expiry, after all allowed retransmissions, aborts the attempt.
  EXPECT_CALL(a.cb, OnAborted).Times(1);
  AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << max_retransmits));

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);
}
738
// While waiting for COOKIE_ACK, a socket that has queued data must send only
// COOKIE_ECHO bundled with DATA, and only when T1-COOKIE drives it (never
// because of T3-RTX). Once connected, the whole message must be delivered.
TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);
  a.socket.Connect();

  // Z reads INIT, produces INIT_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads INIT_ACK, produces COOKIE_ECHO
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  // COOKIE_ECHO is never received by Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2));
  EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType);
  EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType);

  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  // There are DATA chunks in the sent packet (that was lost), which means that
  // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it
  // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires,
  // nothing should be retransmitted.
  ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout);
  AdvanceTime(a, z, a.options.rto_initial);
  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present.
  AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial);

  // And this COOKIE-ECHO and DATA is also lost - never received by Z.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2));
  EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType);
  EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType);

  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  // COOKIE_ECHO has exponential backoff.
  AdvanceTime(a, z, a.options.t1_cookie_timeout * 2);

  // Z reads COOKIE_ECHO, produces COOKIE_ACK
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // A reads COOKIE_ACK.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);

  ExchangeMessages(a, z);
  // Check that a message actually arrived before touching it; dereferencing
  // an empty optional would be undefined behavior.
  absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg.has_value());
  EXPECT_THAT(msg->payload(), SizeIs(kLargeMessageSize));
}
797
TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  RTC_LOG(LS_INFO) << "Shutting down";

  // A starts a graceful shutdown; Z must report that it closed.
  EXPECT_CALL(z->cb, OnClosed).Times(1);
  a.socket.Shutdown();
  // SHUTDOWN: A -> Z. Z answers with SHUTDOWN_ACK.
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // SHUTDOWN_ACK: Z -> A. A answers with SHUTDOWN_COMPLETE.
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
  // SHUTDOWN_COMPLETE: A -> Z.
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);
  EXPECT_EQ(z->socket.state(), SocketState::kClosed);

  // A closed socket must remain closed after a handover.
  z = MaybeHandoverSocket(std::move(z));
  EXPECT_EQ(z->socket.state(), SocketState::kClosed);
}
822
TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  ConnectSockets(a, z);

  a.socket.Shutdown();
  // The first SHUTDOWN never reaches Z.
  a.cb.ConsumeSentPacket();

  EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown);

  const auto max_retransmissions = *a.options.max_retransmissions;
  for (int attempt = 0; attempt < max_retransmissions; ++attempt) {
    // The timeout doubles on every expiry.
    AdvanceTime(a, z, a.options.rto_initial * (1 << attempt));

    // Exactly one SHUTDOWN is resent, and it is dropped as well.
    ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent,
                                SctpPacket::Parse(a.cb.ConsumeSentPacket()));
    EXPECT_EQ(resent.descriptors()[0].type, ShutdownChunk::kType);
    EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
  }

  // The final expiry aborts the connection.
  EXPECT_CALL(a.cb, OnAborted).Times(1);
  AdvanceTime(a, z, a.options.rto_initial * (1 << max_retransmissions));

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket abort_packet,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(abort_packet.descriptors()[0].type, AbortChunk::kType);
  EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
}
855
TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  a.socket.Connect();
  // The message is queued while the handshake is still in progress.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);

  // INIT: A -> Z. Z answers with INIT_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // INIT_ACK: Z -> A. A answers with COOKIE_ECHO.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  // COOKIE_ECHO: A -> Z. Z answers with COOKIE_ACK.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // COOKIE_ACK: Z -> A.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  EXPECT_EQ(a.socket.state(), SocketState::kConnected);
  EXPECT_EQ(z.socket.state(), SocketState::kConnected);

  absl::optional<DcSctpMessage> received = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(received.has_value());
  EXPECT_EQ(received->stream_id(), StreamID(1));
}
880
TEST(DcSctpSocketTest, SendMessageAfterEstablished) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");
  ConnectSockets(a, z);

  // A single small message fits in one packet, which Z receives directly.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());

  absl::optional<DcSctpMessage> received = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(received.has_value());
  EXPECT_EQ(received->stream_id(), StreamID(1));
}
894
// Drops the first transmission of a DATA packet and verifies that advancing
// time by rto_initial makes A send it again, so that Z still gets the message.
TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
  // The original transmission is lost on the wire.
  a.cb.ConsumeSentPacket();

  RTC_LOG(LS_INFO) << "Advancing time";
  AdvanceTime(a, *z, a.options.rto_initial);

  // Deliver the resent packet.
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage received,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(received.stream_id(), StreamID(1));

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
916
// Sends a kLargeMessageSize message, loses the second packet carrying it, and
// verifies that the complete, unmodified payload still reaches Z once the
// sockets have exchanged the remaining traffic.
TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  std::vector<uint8_t> large_payload(kLargeMessageSize);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), large_payload),
                kSendOptions);

  // The first DATA packet arrives...
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // ...but the second one is lost.
  a.cb.ConsumeSentPacket();

  // Retransmit and handle the rest.
  ExchangeMessages(a, *z);

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage received,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(received.stream_id(), StreamID(1));
  EXPECT_THAT(received.payload(), testing::ElementsAreArray(large_payload));

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
942
// Injects a HEARTBEAT request into A and verifies that A answers with a packet
// holding exactly one HEARTBEAT_ACK that echoes the request's info bytes.
TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Inject a HEARTBEAT chunk carrying four bytes of heartbeat info.
  uint8_t heartbeat_info[] = {1, 2, 3, 4};
  Parameters::Builder request_params;
  request_params.Add(HeartbeatInfoParameter(heartbeat_info));
  SctpPacket::Builder request(a.socket.verification_tag(), DcSctpOptions());
  request.Add(HeartbeatRequestChunk(request_params.Build()));
  a.socket.ReceivePacket(request.Build());

  // A replies with a HEARTBEAT_ACK; capture it and check the echoed info.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  ASSERT_THAT(reply.descriptors(), SizeIs(1));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatAckChunk heartbeat_ack,
      HeartbeatAckChunk::Parse(reply.descriptors()[0].data));
  ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter echoed_info,
                              heartbeat_ack.info());
  EXPECT_THAT(echoed_info.info(), ElementsAre(1, 2, 3, 4));

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
970
// Verifies that an idle association makes A send a HEARTBEAT once
// heartbeat_interval has elapsed, that its info parameter is a single 64-bit
// value, and that Z answers it with a packet holding exactly one
// HEARTBEAT_ACK, which is then fed back to A.
TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  AdvanceTime(a, *z, a.options.heartbeat_interval);

  std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
                              SctpPacket::Parse(hb_packet_raw));
  ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatRequestChunk hb,
      HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
  ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, hb.info());

  // The info is a single 64-bit number. Inspect the parameter extracted above
  // instead of parsing it out of the chunk a second time.
  EXPECT_THAT(info_param.info(), SizeIs(8));

  // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
  z->socket.ReceivePacket(hb_packet_raw);
  std::vector<uint8_t> hb_ack_packet_raw = z->cb.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_ack_packet,
                              SctpPacket::Parse(hb_ack_packet_raw));
  ASSERT_THAT(hb_ack_packet.descriptors(), SizeIs(1));
  EXPECT_EQ(hb_ack_packet.descriptors()[0].type, HeartbeatAckChunk::kType);
  a.socket.ReceivePacket(hb_ack_packet_raw);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1000
// Closes Z so that nothing answers A, then lets max_retransmissions HEARTBEATs
// go unanswered. After one more unanswered HEARTBEAT, A is expected to abort
// the association.
TEST_P(DcSctpSocketParametrizedTest,
       CloseConnectionAfterTooManyLostHeartbeats) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  EXPECT_CALL(z->cb, OnClosed).Times(1);
  EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
  // Force-close socket Z so that it doesn't interfere from now on.
  z->socket.Close();

  DurationMs time_until_next_heartbeat = a.options.heartbeat_interval;

  for (int attempt = 0; attempt < *a.options.max_retransmissions; ++attempt) {
    RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
    AdvanceTime(a, *z, time_until_next_heartbeat);

    // Check that a HEARTBEAT was sent, then drop it.
    ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket dropped_heartbeat,
                                SctpPacket::Parse(a.cb.ConsumeSentPacket()));
    EXPECT_EQ(dropped_heartbeat.descriptors()[0].type,
              HeartbeatRequestChunk::kType);

    RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
    AdvanceTime(a, *z, DurationMs(1000));

    // The 1000 ms waited above is subtracted from the next interval.
    time_until_next_heartbeat =
        a.options.heartbeat_interval - DurationMs(1000);
  }

  RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
  AdvanceTime(a, *z, time_until_next_heartbeat);

  // The final heartbeat is sent, and also left unanswered.
  EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty()));

  EXPECT_CALL(a.cb, OnAborted).Times(1);
  // Waiting another 1000 ms should be enough to exceed the RTO.
  AdvanceTime(a, *z, DurationMs(1000));

  z = MaybeHandoverSocket(std::move(z));
}
1043
// Lets max_retransmissions HEARTBEATs go unanswered, then acknowledges the next
// one with a hand-built HEARTBEAT_ACK. Verifies that A does not abort and that
// it keeps sending heartbeats afterwards.
TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
  EXPECT_CALL(z->cb, OnClosed).Times(1);
  // Force-close socket Z so that it doesn't interfere from now on.
  z->socket.Close();

  DurationMs time_until_next_heartbeat = a.options.heartbeat_interval;

  for (int attempt = 0; attempt < *a.options.max_retransmissions; ++attempt) {
    AdvanceTime(a, *z, time_until_next_heartbeat);

    // Dropping every heartbeat.
    a.cb.ConsumeSentPacket();

    RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
    AdvanceTime(a, *z, DurationMs(1000));

    // The 1000 ms waited above is subtracted from the next interval.
    time_until_next_heartbeat =
        a.options.heartbeat_interval - DurationMs(1000);
  }

  RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it";
  AdvanceTime(a, *z, time_until_next_heartbeat);

  std::vector<uint8_t> request_raw = a.cb.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket request_packet,
                              SctpPacket::Parse(request_raw));
  ASSERT_THAT(request_packet.descriptors(), SizeIs(1));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatRequestChunk request,
      HeartbeatRequestChunk::Parse(request_packet.descriptors()[0].data));

  // Acknowledge it by echoing its parameters in a HEARTBEAT_ACK.
  SctpPacket::Builder ack_builder(a.socket.verification_tag(), a.options);
  ack_builder.Add(HeartbeatAckChunk(std::move(request).extract_parameters()));
  a.socket.ReceivePacket(ack_builder.Build());

  // Waiting long enough to exceed the RTO must not abort the association.
  EXPECT_CALL(a.cb, OnAborted).Times(0);
  AdvanceTime(a, *z, DurationMs(1000));

  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  // Verify that we get new heartbeats again.
  RTC_LOG(LS_INFO) << "Expecting a new heartbeat";
  AdvanceTime(a, *z, time_until_next_heartbeat);

  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket next_packet,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  EXPECT_EQ(next_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
}
1099
// Exchanges one message on stream 1 and then resets that outgoing stream.
// Verifies that Z reports the incoming reset once and that A reports the reset
// as performed once Z's response arrives.
TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage received,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(received.stream_id(), StreamID(1));

  // Deliver Z's SACK to A.
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Resetting the outgoing stream directly sends a RE-CONFIG request.
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));

  // On the request, Z reports that A reset its stream and replies with a
  // RE-CONFIG response.
  EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1);
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  // On the response, A reports that the reset has been performed.
  EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1);
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1131
// Sends two ordered messages on stream 1 (SSN 0 and 1), resets the stream, and
// verifies that the next two messages on it are numbered from SSN 0 again.
TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  std::vector<uint8_t> payload(a.options.mtu - 100);

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});

  std::vector<uint8_t> first_before_reset = a.cb.ConsumeSentPacket();
  EXPECT_THAT(first_before_reset, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(first_before_reset);

  std::vector<uint8_t> second_before_reset = a.cb.ConsumeSentPacket();
  EXPECT_THAT(second_before_reset, HasDataChunkWithSsn(SSN(1)));
  z->socket.ReceivePacket(second_before_reset);

  // Handle SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage first_received,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(first_received.stream_id(), StreamID(1));

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage second_received,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(second_received.stream_id(), StreamID(1));

  // Reset the outgoing stream; this directly sends a RE-CONFIG request...
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // ...which Z answers with a RE-CONFIG response.
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});

  std::vector<uint8_t> first_after_reset = a.cb.ConsumeSentPacket();
  EXPECT_THAT(first_after_reset, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(first_after_reset);

  std::vector<uint8_t> second_after_reset = a.cb.ConsumeSentPacket();
  EXPECT_THAT(second_after_reset, HasDataChunkWithSsn(SSN(1)));
  z->socket.ReceivePacket(second_after_reset);

  // Handle SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1187
// Sends two ordered messages on each of streams 1 and 3, then resets only
// stream 3. New messages on stream 1 must continue its SSN sequence (SSN 2),
// while stream 3 must restart at SSN 0.
TEST_P(DcSctpSocketParametrizedTest,
       ResetStreamWillOnlyResetTheRequestedStreams) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  std::vector<uint8_t> payload(a.options.mtu - 100);

  // Send two ordered messages on SID 1
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});

  auto packet1 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1)));
  EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(packet1);

  auto packet2 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet2, HasDataChunkWithStreamId(StreamID(1)));
  EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1)));
  z->socket.ReceivePacket(packet2);

  // Handle SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Do the same, for SID 3
  a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
  a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
  auto packet3 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3)));
  EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(packet3);
  auto packet4 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3)));
  EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1)));
  z->socket.ReceivePacket(packet4);
  // Handle SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Receive all messages.
  absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg1.has_value());
  EXPECT_EQ(msg1->stream_id(), StreamID(1));

  absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg2.has_value());
  EXPECT_EQ(msg2->stream_id(), StreamID(1));

  absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg3.has_value());
  EXPECT_EQ(msg3->stream_id(), StreamID(3));

  absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg4.has_value());
  EXPECT_EQ(msg4->stream_id(), StreamID(3));

  // Reset SID 3 only. This will directly send a RE-CONFIG.
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
  // RE-CONFIG, req
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // RE-CONFIG, resp
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});

  a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});

  auto packet5 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1)));
  EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2)));  // Unchanged.
  z->socket.ReceivePacket(packet5);

  auto packet6 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3)));
  EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0)));  // Reset.
  z->socket.ReceivePacket(packet6);

  // Handle SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1272
// Replaces Z with a brand-new peer (Z2) while A is partway through sending a
// fragmented message. A is expected to report a connection restart, and Z2
// must still receive the complete message.
TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1);
  // Let's be evil here - reconnect while a fragmented packet was about to be
  // sent. The receiving side should get it in full.
  std::vector<uint8_t> large_payload(kLargeMessageSize);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), large_payload),
                kSendOptions);

  // Only the first DATA packet reaches the original Z.
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  // Create a new association, z2 - and don't use z anymore.
  SocketUnderTest z2("Z2");
  z2.socket.Connect();

  // Retransmit and handle the rest. As there will be some chunks in-flight that
  // have the wrong verification tag, those will yield errors.
  ExchangeMessages(a, z2);

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage received,
                              z2.cb.ConsumeReceivedMessage());
  EXPECT_EQ(received.stream_id(), StreamID(1));
  EXPECT_THAT(received.payload(), testing::ElementsAreArray(large_payload));
}
1302
// Sends three messages with max_retransmissions = 0, one per packet, and loses
// the second one on the wire. Verifies that the lost message (PPID 52) is
// abandoned and never delivered, while the first and third (PPID 51 and 53)
// are delivered.
TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // No retransmissions allowed: a lost chunk is abandoned instead of resent.
  SendOptions send_options;
  send_options.max_retransmissions = 0;
  // Presumably sized so that each message fits in exactly one packet; the
  // packet-by-packet handling below relies on that.
  std::vector<uint8_t> payload(a.options.mtu - 100);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);

  // First DATA
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  // Second DATA (lost)
  a.cb.ConsumeSentPacket();
  // Third DATA
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  // Handle SACK for first DATA
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Handle delayed SACK for third DATA
  AdvanceTime(a, *z, a.options.delayed_ack_max_timeout);

  // Handle SACK for second DATA
  // NOTE(review): Z sends this SACK after the third DATA arrived with the
  // second one missing, so it presumably reports the second DATA as a gap
  // rather than acking it - confirm and reword.
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  // Now the missing data chunk will be marked as nacked, but it might still be
  // in-flight and the reported gap could be due to out-of-order delivery. So
  // the RetransmissionQueue will not mark it as "to be retransmitted" until
  // after the t3-rtx timer has expired.
  AdvanceTime(a, *z, a.options.rto_initial);

  // The chunk will be marked as retransmitted, and then as abandoned, which
  // will trigger a FORWARD-TSN to be sent.

  // FORWARD-TSN, moving Z past the abandoned second DATA (PPID 52).
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());

  // Which will trigger a SACK
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());

  absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg1.has_value());
  EXPECT_EQ(msg1->ppid(), PPID(51));

  absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg2.has_value());
  EXPECT_EQ(msg2->ppid(), PPID(53));

  // The abandoned PPID 52 message is never delivered.
  absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
  EXPECT_FALSE(msg3.has_value());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1361
// Sends four unordered messages with max_retransmissions = 0, each split over
// two packets, and loses the second fragment of the first three. Only the
// fourth message (PPID 54), whose fragments all arrive, is expected to be
// delivered; the others are expected to be abandoned once the retransmission
// timer expires.
TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  SendOptions send_options;
  send_options.unordered = IsUnordered(true);
  send_options.max_retransmissions = 0;
  // Large enough to need two packets per message (see the checks below).
  std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */);
  // Sending first message
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
  // Sending second message
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
  // Sending third message
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
  // Sending fourth message
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options);

  // First DATA, first fragment
  std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));
  z->socket.ReceivePacket(std::move(packet));

  // First DATA, second fragment (lost)
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));

  // Second DATA, first fragment
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
  z->socket.ReceivePacket(std::move(packet));

  // Second DATA, second fragment (lost)
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
  EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));

  // Third DATA, first fragment
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
  EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(std::move(packet));

  // Third DATA, second fragment (lost)
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
  EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));

  // Fourth DATA, first fragment
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
  EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(std::move(packet));

  // Fourth DATA, second fragment
  packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
  EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
  z->socket.ReceivePacket(std::move(packet));

  ExchangeMessages(a, *z);

  // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs
  AdvanceTime(a, *z, a.options.rto_initial);

  ExchangeMessages(a, *z);

  // Only the fully received fourth message is delivered...
  absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg1.has_value());
  EXPECT_EQ(msg1->ppid(), PPID(54));

  // ...and nothing else.
  ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1439
1440 struct FakeChunkConfig : ChunkConfig {
1441 static constexpr int kType = 0x49;
1442 static constexpr size_t kHeaderSize = 4;
1443 static constexpr int kVariableLengthAlignment = 0;
1444 };
1445
1446 class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> {
1447 public:
FakeChunk()1448 FakeChunk() {}
1449
1450 FakeChunk(FakeChunk&& other) = default;
1451 FakeChunk& operator=(FakeChunk&& other) = default;
1452
SerializeTo(std::vector<uint8_t> & out) const1453 void SerializeTo(std::vector<uint8_t>& out) const override {
1454 AllocateTLV(out);
1455 }
ToString() const1456 std::string ToString() const override { return "FAKE"; }
1457 };
1458
// Injects a chunk of an unknown type (FakeChunk) into A and verifies that A
// answers with an ERROR chunk. The ERROR's Unrecognized Chunk Type cause must
// carry the offending chunk's header bytes.
TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Inject a FAKE chunk.
  SctpPacket::Builder injected(a.socket.verification_tag(), DcSctpOptions());
  injected.Add(FakeChunk());
  a.socket.ReceivePacket(injected.Build());

  // The reply must be a packet holding a single ERROR chunk.
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply,
                              SctpPacket::Parse(a.cb.ConsumeSentPacket()));
  ASSERT_THAT(reply.descriptors(), SizeIs(1));
  ASSERT_HAS_VALUE_AND_ASSIGN(ErrorChunk error_chunk,
                              ErrorChunk::Parse(reply.descriptors()[0].data));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      UnrecognizedChunkTypeCause unrecognized,
      error_chunk.error_causes().get<UnrecognizedChunkTypeCause>());
  // Type 0x49, flags 0, length 4: the FAKE chunk exactly as it was sent.
  EXPECT_THAT(unrecognized.unrecognized_chunk(),
              ElementsAre(0x49, 0x00, 0x00, 0x04));

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1484
// Injects an ERROR chunk carrying an Unrecognized Chunk Type cause into A and
// verifies that A surfaces it through OnError as a peer-reported error.
TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Build an ERROR chunk with a single Unrecognized Chunk Type cause.
  Parameters::Builder causes;
  causes.Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04}));
  SctpPacket::Builder injected(a.socket.verification_tag(), DcSctpOptions());
  injected.Add(ErrorChunk(causes.Build()));

  EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported,
                            HasSubstr("Unrecognized Chunk Type")));
  a.socket.ReceivePacket(injected.Build());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1505
// Gives Z a small receive window (kReceiveWindowBufferSize bytes). Once Z's
// reassembly queue is filled past the high watermark, Z should:
// - accept only DATA that advances the cumulative ack TSN, and drop
//   out-of-order DATA;
// - once the queue is full, abort the association with
//   ErrorKind::kResourceExhaustion on further DATA.
TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
  SocketUnderTest a("A");

  constexpr size_t kReceiveWindowBufferSize = 2000;
  SocketUnderTest z(
      "Z", {.mtu = 3000,
            .max_receiver_window_buffer_size = kReceiveWindowBufferSize});

  EXPECT_CALL(z.cb, OnClosed).Times(0);
  EXPECT_CALL(z.cb, OnAborted).Times(0);

  // Run the handshake by hand so that A's INIT chunk can be captured. Its
  // initial TSN is the base for the hand-built DATA chunks below.
  a.socket.Connect();
  std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
                              SctpPacket::Parse(init_data));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      InitChunk init_chunk,
      InitChunk::Parse(init_packet.descriptors()[0].data));
  z.socket.ReceivePacket(init_data);
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  // Fill up Z to the high watermark limit.
  // NOTE(review): kHighWatermarkLimit appears to be a fraction of the buffer
  // size, since it scales kReceiveWindowBufferSize. Confirm.
  constexpr size_t kWatermarkLimit =
      kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit;
  constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit;

  TSN tsn = init_chunk.initial_tsn();
  // Marked only as a beginning fragment, so this payload presumably stays in
  // Z's reassembly queue instead of being delivered.
  AnyDataChunk::Options opts;
  opts.is_beginning = Data::IsBeginning(true);
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(kWatermarkLimit + 1), opts))
          .Build());

  // First DATA will always trigger a SACK. It's not interesting.
  EXPECT_THAT(z.cb.ConsumeSentPacket(),
              AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks()));

  // This DATA should be accepted - it's advancing cum ack tsn.
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(1),
                         /*options=*/{}))
          .Build());

  // The receiver might have moved into delayed ack mode.
  AdvanceTime(a, z, z.options.rto_initial);

  EXPECT_THAT(
      z.cb.ConsumeSentPacket(),
      AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));

  // This DATA will not be accepted - it's not advancing cum ack tsn.
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(1),
                         /*options=*/{}))
          .Build());

  // Sack will be sent in IMMEDIATE mode when this is happening.
  EXPECT_THAT(
      z.cb.ConsumeSentPacket(),
      AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));

  // This DATA will not be accepted either.
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(1),
                         /*options=*/{}))
          .Build());

  // Sack will be sent in IMMEDIATE mode when this is happening.
  EXPECT_THAT(
      z.cb.ConsumeSentPacket(),
      AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));

  // This DATA should be accepted, and it fills the reassembly queue.
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(kRemainingSize),
                         /*options=*/{}))
          .Build());

  // The receiver might have moved into delayed ack mode.
  AdvanceTime(a, z, z.options.rto_initial);

  EXPECT_THAT(
      z.cb.ConsumeSentPacket(),
      AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks()));

  // These expectations take precedence over the Times(0) ones set up above.
  EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _));
  EXPECT_CALL(z.cb, OnClosed).Times(0);

  // This DATA will make the connection close. It's too full now.
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(kSmallMessageSize),
                         /*options=*/{}))
          .Build());
}
1614
// Verifies that a value passed to SetMaxMessageSize shows up in the socket's
// options.
TEST(DcSctpSocketTest, SetMaxMessageSize) {
  constexpr size_t kNewMaxMessageSize = 42;
  SocketUnderTest a("A");

  a.socket.SetMaxMessageSize(kNewMaxMessageSize);
  EXPECT_EQ(a.socket.options().max_message_size, kNewMaxMessageSize);
}
1621
// Queues kIterations small messages with lifetimes of 0-2 ms, alternating
// between ordered and unordered. The mocked clock advances 3 ms on every read.
// Verifies that every message is still delivered, exactly once.
TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Mock that the time always goes forward: each clock read on either socket
  // advances one shared clock by 3 ms.
  TimeMs now(0);
  auto tick = [&now]() {
    now += DurationMs(3);
    return now;
  };
  EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly(tick);
  EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly(tick);

  // Queue a few small messages with low lifetime, both ordered and unordered,
  // and validate that all are delivered.
  static constexpr int kIterations = 100;
  for (int index = 0; index < kIterations; ++index) {
    SendOptions send_options;
    send_options.unordered = IsUnordered(index % 2 == 0);
    send_options.lifetime = DurationMs(index % 3);  // 0, 1, 2 ms
    a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
  }

  ExchangeMessages(a, *z);

  for (int delivered = 0; delivered < kIterations; ++delivered) {
    EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value());
  }
  // Nothing beyond the queued messages.
  EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());

  // Validate that the sockets really make the time move forward.
  EXPECT_GE(*now, kIterations * 2);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1664
// Fills the send path with a large, reliably sent message, then queues small
// unordered messages with a lifetime of 0 or 1 ms behind it. The mocked clock
// advances 3 ms per read. Only the large message is expected to be delivered;
// none of the short-lived ones are.
TEST_P(DcSctpSocketParametrizedTest,
       DiscardsMessagesWithLowLifetimeIfMustBuffer) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  SendOptions zero_lifetime;
  zero_lifetime.unordered = IsUnordered(true);
  zero_lifetime.lifetime = DurationMs(0);

  SendOptions one_ms_lifetime;
  one_ms_lifetime.unordered = IsUnordered(true);
  one_ms_lifetime.lifetime = DurationMs(1);

  // Mock that the time always goes forward: each clock read on either socket
  // advances one shared clock by 3 ms.
  TimeMs now(0);
  auto tick = [&now]() {
    now += DurationMs(3);
    return now;
  };
  EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly(tick);
  EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly(tick);

  // Fill up the send buffer with a large message.
  std::vector<uint8_t> large_payload(kLargeMessageSize);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), large_payload),
                kSendOptions);

  // And queue a few small messages with lifetime=0 or 1 ms - can't be sent.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}),
                zero_lifetime);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}),
                one_ms_lifetime);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}),
                zero_lifetime);

  // Handle all that was sent until congestion window got full.
  for (std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
       !packet.empty(); packet = a.cb.ConsumeSentPacket()) {
    z->socket.ReceivePacket(packet);
  }

  // Shouldn't be enough to send that large message.
  EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());

  // Exchange the rest of the messages, with the time ever increasing.
  ExchangeMessages(a, *z);

  // The large message should be delivered. It was sent reliably.
  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage large_message,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(large_message.stream_id(), StreamID(1));
  EXPECT_THAT(large_message.payload(), SizeIs(kLargeMessageSize));

  // But none of the smaller messages.
  EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1726
TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Current number of bytes A still has buffered for stream 1.
  auto buffered_on_stream_1 = [&a]() {
    return a.socket.buffered_amount(StreamID(1));
  };

  EXPECT_EQ(buffered_on_stream_1(), 0u);

  // A small message goes out immediately as a single packet, so nothing stays
  // queued.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  EXPECT_EQ(buffered_on_stream_1(), 0u);

  // A large message starts transmitting right away, so part of it, but not
  // all of it, remains buffered.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);
  EXPECT_GT(buffered_on_stream_1(), 0u);
  EXPECT_LT(buffered_on_stream_1(), kLargeMessageSize);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1754
TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
  // Without any configuration, a stream's buffered-amount-low threshold is 0.
  SocketUnderTest sock("A");
  const auto threshold = sock.socket.buffered_amount_low_threshold(StreamID(1));
  EXPECT_EQ(threshold, 0u);
}
1759
TEST_P(DcSctpSocketParametrizedTest,
       TriggersOnBufferedAmountLowWithDefaultValueZero) {
  // With the default buffered-amount-low threshold (0), draining a sent
  // message is expected to fire OnBufferedAmountLow for its stream.
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  // No callback may fire while the connection is being established.
  EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // No explicit Times(): gMock's default cardinality here is exactly once.
  EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1)));
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  ExchangeMessages(a, *z);

  // Allow any number of further callbacks during the final step, which
  // (presumably) sends another message of its own.
  EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return());
  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1778
TEST_P(DcSctpSocketParametrizedTest,
       DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
  static constexpr size_t kMessageSize = 1000;
  static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;

  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  a.socket.SetBufferedAmountLowThreshold(StreamID(1),
                                         kBufferedAmountLowThreshold);
  EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Each message is one tenth of the threshold, so the buffered amount never
  // starts out above it and the callback must never fire.
  EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0);
  for (int round = 0; round < 2; ++round) {
    a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                                std::vector<uint8_t>(kMessageSize)),
                  kSendOptions);
    ExchangeMessages(a, *z);
  }

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1806
TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
  static constexpr size_t kMessageSize = 1000;
  static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;

  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  a.socket.SetBufferedAmountLowThreshold(StreamID(1),
                                         kBufferedAmountLowThreshold);
  EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Messages alternate between streams 1 and 2 (three on stream 1, two on
  // stream 2). Every drained message is expected to fire one callback for its
  // own stream.
  EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3);
  EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2);
  const StreamID send_order[] = {StreamID(1), StreamID(2), StreamID(1),
                                 StreamID(2), StreamID(1)};
  for (const StreamID& stream : send_order) {
    a.socket.Send(
        DcSctpMessage(stream, PPID(53), std::vector<uint8_t>(kMessageSize)),
        kSendOptions);
    ExchangeMessages(a, *z);
  }

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1849
TEST_P(DcSctpSocketParametrizedTest,
       TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
  // OnBufferedAmountLow must fire exactly once: when the buffered amount of
  // stream 1 falls back down after having been pushed above its threshold.
  static constexpr size_t kMessageSize = 1000;
  // One and a half messages' worth of bytes (1500). Integer arithmetic avoids
  // an implicit double -> size_t conversion.
  static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 3 / 2;

  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  a.socket.SetBufferedAmountLowThreshold(StreamID(1),
                                         kBufferedAmountLowThreshold);
  EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Filling the buffer must not trigger the callback.
  EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);

  // Add a few messages to fill up the congestion window. When that is full,
  // messages will start to be fully buffered.
  while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
    a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                                std::vector<uint8_t>(kMessageSize)),
                  kSendOptions);
  }
  const size_t initial_buffered = a.socket.buffered_amount(StreamID(1));
  ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);

  // Start ACKing packets, which will empty the send queue, and trigger the
  // callback.
  EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1);
  ExchangeMessages(a, *z);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1883
TEST_P(DcSctpSocketParametrizedTest,
       DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // One large message presumably keeps the total buffered amount below the
  // total-low threshold, so draining it must not fire the callback.
  EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);

  std::vector<uint8_t> large_payload(kLargeMessageSize);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), std::move(large_payload)),
                kSendOptions);
  ExchangeMessages(a, *z);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1902
TEST_P(DcSctpSocketParametrizedTest,
       TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);

  // Keep enqueuing large messages until the socket rejects one because its
  // send buffer is exhausted.
  while (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                                     std::vector<uint8_t>(kLargeMessageSize)),
                       kSendOptions) != SendStatus::kErrorResourceExhaustion) {
  }

  // Draining the full queue crosses the threshold exactly once.
  EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1);
  ExchangeMessages(a, *z);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
1927
TEST(DcSctpSocketTest, InitialMetricsAreUnset) {
  // A socket that has never connected exposes no metrics.
  SocketUnderTest unconnected("A");
  const auto metrics = unconnected.socket.GetMetrics();
  EXPECT_FALSE(metrics.has_value());
}
1933
TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) {
  // Message interleaving is reported as in use only when both peers enable
  // it. All four enable/disable combinations are checked.
  std::vector<std::pair<bool, bool>> combinations = {
      {false, false}, {false, true}, {true, false}, {true, true}};
  for (const auto& [a_enable, z_enable] : combinations) {
    // Label any failure with the combination that produced it.
    SCOPED_TRACE(testing::Message() << "a_enable=" << a_enable
                                    << ", z_enable=" << z_enable);
    DcSctpOptions a_options = {.enable_message_interleaving = a_enable};
    DcSctpOptions z_options = {.enable_message_interleaving = z_enable};

    SocketUnderTest a("A", a_options);
    SocketUnderTest z("Z", z_options);
    ConnectSockets(a, z);

    // Fail cleanly instead of dereferencing an empty optional (UB).
    auto metrics = a.socket.GetMetrics();
    ASSERT_TRUE(metrics.has_value());
    EXPECT_EQ(metrics->uses_message_interleaving, a_enable && z_enable);
  }
}
1949
TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
  // Runs a short exchange and, after each step, checks the values reported by
  // GetMetrics(): packet and message counters, congestion window, peer
  // receive window, and the number of unacknowledged DATA chunks.
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  ConnectSockets(a, z);

  // Receive window a peer advertises when it has nothing buffered. Both
  // sockets use default options, so A's options stand in for Z's here.
  const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size *
                                ReassemblyQueue::kHighWatermarkLimit;

  // Connection establishment accounts for two packets in each direction.
  EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u);
  EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u);
  EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u);
  EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes,
            a.options.cwnd_mtus_initial * a.options.mtu);
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);

  EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u);
  EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u);

  // A small message: one DATA chunk stays outstanding until it is SACKed.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);

  z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
  EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);

  EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());

  EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u);
  EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u);
  EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u);

  EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u);
  EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u);

  // Send one more (large - fragmented), and receive the delayed SACK.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(a.options.mtu * 2 + 1)),
                kSendOptions);
  // mtu * 2 + 1 bytes do not fit in two DATA chunks, so three are queued.
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u);

  z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA

  // The SACK covers the first two fragments. The third is still outstanding,
  // and the partially reassembled message is expected to shrink Z's
  // advertised window.
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
  EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u);
  EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);

  z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA

  EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());

  EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u);
  EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u);
  EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u);

  EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u);
  EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u);

  // Delayed sack
  // The last DATA was not acknowledged immediately. Advancing time by the
  // delayed-ack timeout makes Z emit the SACK for it.
  AdvanceTime(a, z, a.options.delayed_ack_max_timeout);

  a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
  EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
  EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u);
  EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
}
2019
TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);

  // User payload carried by one full-MTU DATA packet.
  const size_t bytes_per_packet =
      a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
  // The initial congestion window lets this many packets leave at once...
  const size_t in_flight = a.options.cwnd_mtus_initial;
  // ...and the rest of the message waits in the send queue as roughly this
  // many chunks.
  const size_t still_queued =
      (kLargeMessageSize - in_flight * bytes_per_packet) / bytes_per_packet;
  const size_t min_expected_unacked = in_flight + still_queued;

  // Padding and alignment make the exact count hard to predict, so allow a
  // small margin above the estimate.
  EXPECT_GE(a.socket.GetMetrics()->unack_data_count, min_expected_unacked);
  EXPECT_LE(a.socket.GetMetrics()->unack_data_count, min_expected_unacked + 2);

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
2050
TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);

  // The first burst contains exactly kMaxBurstPackets DATA packets...
  for (int burst_index = 0; burst_index < kMaxBurstPackets; ++burst_index) {
    std::vector<uint8_t> data_packet = a.cb.ConsumeSentPacket();
    EXPECT_THAT(data_packet, Not(IsEmpty()));
    z->socket.ReceivePacket(std::move(data_packet));
  }
  // ...and no further packet is emitted at this point.
  EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());

  ExchangeMessages(a, *z);
  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
2073
TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // A really large message, to ensure that the congestion window is often full.
  constexpr size_t kMessageSize = 100000;
  a.socket.Send(
      DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
      kSendOptions);

  // Move at most one packet per direction per round until neither side has
  // anything left to send, recording the size of every packet from A.
  std::vector<size_t> data_packet_sizes;
  for (bool progressed = true; progressed;) {
    progressed = false;

    std::vector<uint8_t> from_a = a.cb.ConsumeSentPacket();
    if (!from_a.empty()) {
      data_packet_sizes.push_back(from_a.size());
      progressed = true;
      z->socket.ReceivePacket(std::move(from_a));
    }

    std::vector<uint8_t> from_z = z->cb.ConsumeSentPacket();
    if (!from_z.empty()) {
      progressed = true;
      a.socket.ReceivePacket(std::move(from_z));
    }
  }

  const size_t payload_per_packet =
      a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
  // Ceiling division, plus one extra packet to account for padding.
  const size_t expected_packets =
      (kMessageSize + payload_per_packet - 1) / payload_per_packet + 1;
  EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets));

  // Every packet except the final remainder should be (nearly) MTU-sized;
  // 4 bytes of slack cover padding/alignment.
  data_packet_sizes.pop_back();
  for (const size_t packet_size : data_packet_sizes) {
    EXPECT_GE(packet_size, a.options.mtu - 4);
  }

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
2122
TEST(DcSctpSocketTest, SendMessagesAfterHandover) {
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);

  // Exchange one message first so Z is no longer in its initial state when it
  // is handed over.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  z->cb.ConsumeReceivedMessage();

  z = HandoverSocket(std::move(z));

  // The handed-over Z must receive on an already-used stream...
  RTC_LOG(LS_INFO) << "Sending A #1";
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage first_at_z,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(first_at_z.stream_id(), StreamID(1));
  EXPECT_THAT(first_at_z.payload(), testing::ElementsAre(3, 4));

  // ...and on a stream it has not seen before...
  RTC_LOG(LS_INFO) << "Sending A #2";
  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
  z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage second_at_z,
                              z->cb.ConsumeReceivedMessage());
  EXPECT_EQ(second_at_z.stream_id(), StreamID(2));
  EXPECT_THAT(second_at_z.payload(), testing::ElementsAre(5, 6));

  // ...and must still be able to send. Its first packet carries the pending
  // acknowledgement, the second one the data.
  RTC_LOG(LS_INFO) << "Sending Z #1";
  z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());  // ack
  a.socket.ReceivePacket(z->cb.ConsumeSentPacket());  // data

  ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage first_at_a,
                              a.cb.ConsumeReceivedMessage());
  EXPECT_EQ(first_at_a.stream_id(), StreamID(1));
  EXPECT_THAT(first_at_a.payload(), testing::ElementsAre(1, 2, 3));
}
2169
TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  ConnectSockets(a, z);

  // A initiated the connection and learns that its peer runs dcSCTP...
  const SctpImplementation seen_by_a = a.socket.peer_implementation();
  EXPECT_EQ(seen_by_a, SctpImplementation::kDcsctp);

  // ...whereas Z, the responder, does not receive enough information to tell
  // which implementation A is running.
  const SctpImplementation seen_by_z = z.socket.peer_implementation();
  EXPECT_EQ(seen_by_z, SctpImplementation::kUnknown);
}
2182
TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) {
  // Both sides call Connect(), so each one is expected to learn that its peer
  // runs dcSCTP (compare CanDetectDcsctpImplementation, where only A does).
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Even though both sides initiate, each reports the connection exactly once.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  a.socket.Connect();
  z.socket.Connect();

  // Deliver all pending packets in both directions.
  ExchangeMessages(a, z);

  EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
  EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp);
}
2197
TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) {
  // When the first ordered message on a stream is abandoned (its only
  // transmission is dropped and no retransmissions are allowed), later
  // ordered messages on that stream must still be delivered.
  SocketUnderTest a("A");
  auto z = std::make_unique<SocketUnderTest>("Z");

  ConnectSockets(a, *z);
  z = MaybeHandoverSocket(std::move(z));

  // Ordered delivery, and no retransmissions.
  SendOptions send_options;
  send_options.unordered = IsUnordered(false);
  send_options.max_retransmissions = 0;
  // 100 bytes below the MTU should leave room for packet and chunk headers,
  // so each message is expected to fit in a single packet.
  std::vector<uint8_t> payload(a.options.mtu - 100);

  // Send a first message (SID=1, SSN=0)
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);

  // First DATA is lost, and retransmission timer will delete it.
  a.cb.ConsumeSentPacket();
  AdvanceTime(a, *z, a.options.rto_initial);
  ExchangeMessages(a, *z);

  // Send a second message (SID=1, SSN=1).
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
  ExchangeMessages(a, *z);

  // The Z socket should receive the second message, but not the first.
  absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg.has_value());
  EXPECT_EQ(msg->ppid(), PPID(52));

  EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());

  MaybeHandoverSocketAndSendMessage(a, std::move(z));
}
2231
TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) {
  /* This issue was found by fuzzing. */
  // The test has no expectations: it passes as long as Z handles the
  // conflicting DATA chunks below without crashing.
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Run the handshake by hand so A's INIT can be parsed for its initial TSN.
  a.socket.Connect();
  std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
                              SctpPacket::Parse(init_data));
  ASSERT_HAS_VALUE_AND_ASSIGN(
      InitChunk init_chunk,
      InitChunk::Parse(init_packet.descriptors()[0].data));
  z.socket.ReceivePacket(init_data);
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  a.socket.ReceivePacket(z.cb.ConsumeSentPacket());

  // Receive a short unordered message with tsn=INITIAL_TSN+1
  TSN tsn = init_chunk.initial_tsn();
  AnyDataChunk::Options opts;
  opts.is_beginning = Data::IsBeginning(true);
  opts.is_end = Data::IsEnd(true);
  opts.is_unordered = IsUnordered(true);
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(10), opts))
          .Build());

  // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1].
  // This isn't allowed, since it reuses TSN=INITIAL_TSN+1 (already received
  // above as a complete unordered message) with different properties, but it
  // shouldn't cause any issues.
  opts.is_unordered = IsUnordered(false);
  opts.is_end = Data::IsEnd(false);
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(10), opts))
          .Build());

  // Final fragment of the ordered message, reusing TSN=INITIAL_TSN+1.
  opts.is_beginning = Data::IsBeginning(false);
  opts.is_end = Data::IsEnd(true);
  z.socket.ReceivePacket(
      SctpPacket::Builder(z.socket.verification_tag(), z.options)
          .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
                         std::vector<uint8_t>(10), opts))
          .Build());
}
2280
TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) {
  // Reported as https://crbug.com/1312009.
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Each stream is reset, and reported on both sides, exactly once.
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1);

  ConnectSockets(a, z);

  const StreamID streams[] = {StreamID(1), StreamID(2)};

  // Open both streams by sending on them.
  for (const StreamID& stream : streams) {
    a.socket.Send(DcSctpMessage(stream, PPID(53), {1, 2}), kSendOptions);
  }
  ExchangeMessages(a, z);

  // Request the two resets back to back, without delivering anything between
  // them.
  for (const StreamID& stream : streams) {
    a.socket.ResetStreams(std::vector<StreamID>({stream}));
  }
  ExchangeMessages(a, z);
}
2303
TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) {
  // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two
  // remaining streams are reset at the same time in the second request.
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(
                        UnorderedElementsAre(StreamID(2), StreamID(3))))
      .Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(
                        UnorderedElementsAre(StreamID(2), StreamID(3))))
      .Times(1);

  ConnectSockets(a, z);

  const StreamID streams[] = {StreamID(1), StreamID(2), StreamID(3)};

  // Open all three streams by sending on them.
  for (const StreamID& stream : streams) {
    a.socket.Send(DcSctpMessage(stream, PPID(53), {1, 2}), kSendOptions);
  }
  ExchangeMessages(a, z);

  // Request one reset per stream, back to back.
  for (const StreamID& stream : streams) {
    a.socket.ResetStreams(std::vector<StreamID>({stream}));
  }
  ExchangeMessages(a, z);
}
2333
TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) {
  // Checks that stream reset requests are properly paused when they can't be
  // immediately reset - i.e. when there is already an ongoing stream reset
  // request (and there can only be a single one in-flight).
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(
                        UnorderedElementsAre(StreamID(2), StreamID(3))))
      .Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
  EXPECT_CALL(a.cb, OnStreamsResetPerformed(
                        UnorderedElementsAre(StreamID(2), StreamID(3))))
      .Times(1);

  ConnectSockets(a, z);

  SendOptions send_options = {.unordered = IsUnordered(false)};
  const StreamID streams[] = {StreamID(1), StreamID(2), StreamID(3)};

  // One ordered message per stream; Z must receive them in stream order.
  for (const StreamID& stream : streams) {
    a.socket.Send(DcSctpMessage(stream, PPID(53), {1, 2}), send_options);
  }
  ExchangeMessages(a, z);
  for (const StreamID& expected_stream : streams) {
    absl::optional<DcSctpMessage> received = z.cb.ConsumeReceivedMessage();
    ASSERT_TRUE(received.has_value());
    EXPECT_EQ(received->stream_id(), expected_stream);
  }

  // Reset stream 1 alone, and hand that request to Z...
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
  std::vector<uint8_t> reconfig_packet = a.cb.ConsumeSentPacket();
  EXPECT_THAT(reconfig_packet,
              HasReconfigWithStreams(ElementsAre(StreamID(1))));
  z.socket.ReceivePacket(std::move(reconfig_packet));

  // ...then request the other two while the first one is still in flight.
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
  ExchangeMessages(a, z);

  // After the resets, ordered messages still flow on every stream.
  for (const StreamID& stream : streams) {
    a.socket.Send(DcSctpMessage(stream, PPID(53), {1, 2}), send_options);
  }
  ExchangeMessages(a, z);
  for (const StreamID& expected_stream : streams) {
    absl::optional<DcSctpMessage> received = z.cb.ConsumeReceivedMessage();
    ASSERT_TRUE(received.has_value());
    EXPECT_EQ(received->stream_id(), expected_stream);
  }
}
2404
TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
  SocketUnderTest a("A", options);
  const StreamPriority expected_priority = options.default_stream_priority;

  // A stream that has never been used reports the configured default...
  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), expected_priority);

  // ...and so does a stream that comes into existence by being sent on.
  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), expected_priority);
}
2417
TEST(DcSctpSocketTest, CanChangeStreamPriority) {
  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
  SocketUnderTest a("A", options);
  const StreamPriority updated_priority(43);

  // The priority can be overridden on a stream that has not been used yet...
  a.socket.SetStreamPriority(StreamID(1), updated_priority);
  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), updated_priority);

  // ...as well as on one that already exists because it was sent on.
  a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
  a.socket.SetStreamPriority(StreamID(2), updated_priority);
  EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), updated_priority);
}
2430
TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
  DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
  auto a = std::make_unique<SocketUnderTest>("A", options);
  SocketUnderTest z("Z");
  const StreamPriority custom_priority(43);

  ConnectSockets(*a, z);

  // Override the default priority on an unused stream and on a stream that
  // carries traffic.
  a->socket.SetStreamPriority(StreamID(1), custom_priority);
  a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
  a->socket.SetStreamPriority(StreamID(2), custom_priority);

  ExchangeMessages(*a, z);

  a = MaybeHandoverSocket(std::move(a));

  // Both overrides must survive a handover, if one took place.
  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), custom_priority);
  EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), custom_priority);
}
2449
TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) {
  // This is an issue found by fuzzing, and doesn't really make sense in WebRTC
  // data channels as a SCTP connection is never ever closed and then
  // reconnected. SCTP connections are closed when the peer connection is
  // deleted, and then it doesn't do more with SCTP.
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  ConnectSockets(a, z);

  // Request a stream reset and close A before that request can complete.
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));

  // Z is expected to report the abort exactly once by the end of the test.
  EXPECT_CALL(z.cb, OnAborted).Times(1);
  a.socket.Close();

  EXPECT_EQ(a.socket.state(), SocketState::kClosed);

  // Reconnect and issue a new reset request. This must not trip over any state
  // left behind by the request made before closing.
  EXPECT_CALL(a.cb, OnConnected).Times(1);
  EXPECT_CALL(z.cb, OnConnected).Times(1);
  a.socket.Connect();
  ExchangeMessages(a, z);
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
}
2473
TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) {
  DcSctpOptions options = {.enable_message_interleaving = true};
  SocketUnderTest a("A", options);
  SocketUnderTest z("Z", options);

  // Stream 1 gets the highest priority, stream 3 the lowest.
  a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
  a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
  a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));

  // Enqueue messages before connecting the socket, to ensure they aren't sent
  // as soon as Send() is called.
  a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(103),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);

  ConnectSockets(a, z);
  ExchangeMessages(a, z);

  // Delivery follows stream priority, not the order of the Send() calls.
  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103, 201, 301));
}
2515
TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) {
  // Same as the small-message variant, but with messages large enough to be
  // fragmented: delivery must still follow stream priority, and send order
  // within a stream.
  DcSctpOptions options = {.enable_message_interleaving = true};
  SocketUnderTest a("A", options);
  SocketUnderTest z("Z", options);

  // A higher StreamPriority value is served first (see the expectation at the
  // end of this test).
  a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
  a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
  a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));

  // Enqueue messages before connecting the socket, to ensure they aren't sent
  // as soon as Send() is called. The enqueue order deliberately differs from
  // the priority order.
  a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);

  ConnectSockets(a, z);
  ExchangeMessages(a, z);

  // Stream 1 (in send order), then stream 2, then stream 3.
  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301));
}
2545
TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) {
  DcSctpOptions options = {.enable_message_interleaving = true};
  SocketUnderTest a("A", options);
  SocketUnderTest z("Z", options);
  ConnectSockets(a, z);

  const StreamID low_prio_stream(2);
  const StreamID high_prio_stream(1);

  // Begin sending a large message on the low-priority stream. Because the
  // initial congestion window is non-zero, part of it goes out right away, but
  // the congestion window (or the max-burst limit) stops it from being sent in
  // full at this point. What matters is that it is still only partially sent.
  a.socket.SetStreamPriority(low_prio_stream, StreamPriority(128));
  a.socket.Send(DcSctpMessage(low_prio_stream, PPID(201),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);

  // While that message is in progress, queue a small and then a large message
  // on a stream with a higher priority.
  a.socket.SetStreamPriority(high_prio_stream, StreamPriority(512));
  a.socket.Send(DcSctpMessage(high_prio_stream, PPID(101),
                              std::vector<uint8_t>(kSmallMessageSize)),
                kSendOptions);
  a.socket.Send(DcSctpMessage(high_prio_stream, PPID(102),
                              std::vector<uint8_t>(kLargeMessageSize)),
                kSendOptions);

  ExchangeMessages(a, z);

  // Both higher-priority messages complete before the interrupted one.
  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201));
}
2577
TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");
  ConnectSockets(a, z);

  // Builds a large message on stream 2 carrying the given PPID.
  auto large_message = [](PPID ppid) {
    return DcSctpMessage(StreamID(2), ppid,
                         std::vector<uint8_t>(kLargeMessageSize));
  };

  // The first and last messages carry explicit lifecycle ids; the middle one
  // is sent with the shared kSendOptions.
  a.socket.Send(large_message(PPID(101)), {.lifecycle_id = LifecycleId(41)});
  a.socket.Send(large_message(PPID(102)), kSendOptions);
  a.socket.Send(large_message(PPID(103)), {.lifecycle_id = LifecycleId(42)});

  // Each tracked message must be reported both as delivered and as ended.
  EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41)));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41)));
  EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42)));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42)));

  ExchangeMessages(a, z);
  // The last SACK may be delayed; let the delayed-ack timeout elapse.
  AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
  ExchangeMessages(a, z);

  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103));
}
2606
TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");
  ConnectSockets(a, z);

  // Three messages, each sized at MTU minus 100 bytes and with retransmissions
  // disabled. Message N (N = 1..3) uses PPID 50+N and lifecycle id N.
  std::vector<uint8_t> payload(a.options.mtu - 100);
  for (int n = 1; n <= 3; ++n) {
    a.socket.Send(DcSctpMessage(StreamID(1), PPID(50 + n), payload),
                  {
                      .max_retransmissions = 0,
                      .lifecycle_id = LifecycleId(n),
                  });
  }

  // Deliver the first DATA packet to Z, and drop the second one.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  a.cb.ConsumeSentPacket();

  // Messages 1 and 3 are delivered; message 2 expires with
  // maybe_delivered=true.
  EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1)));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
  EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2),
                                              /*maybe_delivered=*/true));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2)));
  EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3)));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3)));

  // Advances past the delayed-ack timeout so that a pending SACK is sent.
  auto flush_delayed_sack = [&] {
    AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
    ExchangeMessages(a, z);
  };

  ExchangeMessages(a, z);
  flush_delayed_sack();

  // The lost chunk is now NACKed. Let the RTO expire so it is discarded.
  AdvanceTime(a, z, a.options.rto_initial);
  ExchangeMessages(a, z);
  flush_delayed_sack();

  EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53));
}
2657
TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");
  ConnectSockets(a, z);

  // A large message cannot be sent in full within the congestion window; more
  // fragments are only sent once SACKs come back. Retransmissions are
  // disabled, so losing any fragment is fatal for the message.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51),
                              std::vector<uint8_t>(kLargeMessageSize)),
                {
                    .max_retransmissions = 0,
                    .lifecycle_id = LifecycleId(1),
                });

  // Deliver the first DATA packet to Z, and drop the second one.
  z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
  a.cb.ConsumeSentPacket();

  // The message is reported as expired (with maybe_delivered=false) and
  // ended.
  EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
                                              /*maybe_delivered=*/false));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
  ExchangeMessages(a, z);

  // Z never gets a complete message.
  EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
}
2684
TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  // Queue the message while the socket is still unconnected, so that it cannot
  // go out immediately. Its 100 ms lifetime should therefore run out before
  // any attempt to send it.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51),
                              std::vector<uint8_t>(kSmallMessageSize)),
                {
                    .lifetime = DurationMs(100),
                    .lifecycle_id = LifecycleId(1),
                });

  // Let twice the lifetime elapse before connecting.
  AdvanceTime(a, z, DurationMs(200));

  EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
                                              /*maybe_delivered=*/false));
  EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));

  ConnectSockets(a, z);
  ExchangeMessages(a, z);

  // The expired message never reaches Z.
  EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
}
2709
TEST_P(DcSctpSocketParametrizedTest, ExposesTheNumberOfNegotiatedStreams) {
  // Each endpoint announces distinct limits so that every negotiated value can
  // be traced to its source. In each direction the result is the smaller of
  // the sender's outgoing limit and the receiver's incoming limit:
  //   A incoming = min(A in 12, Z out 34) = 12
  //   A outgoing = min(A out 45, Z in 23) = 23
  //   Z incoming = min(Z in 23, A out 45) = 23
  //   Z outgoing = min(Z out 34, A in 12) = 12
  DcSctpOptions a_options = {
      .announced_maximum_incoming_streams = 12,
      .announced_maximum_outgoing_streams = 45,
  };
  DcSctpOptions z_options = {
      .announced_maximum_incoming_streams = 23,
      .announced_maximum_outgoing_streams = 34,
  };

  SocketUnderTest a("A", a_options);
  auto z = std::make_unique<SocketUnderTest>("Z", z_options);

  ConnectSockets(a, *z);
  // Depending on the test parameter, Z may be replaced by a handed-over copy.
  z = MaybeHandoverSocket(std::move(z));

  ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_a, a.socket.GetMetrics());
  EXPECT_EQ(metrics_a.negotiated_maximum_incoming_streams, 12);
  EXPECT_EQ(metrics_a.negotiated_maximum_outgoing_streams, 23);

  ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_z, z->socket.GetMetrics());
  EXPECT_EQ(metrics_z.negotiated_maximum_incoming_streams, 23);
  EXPECT_EQ(metrics_z.negotiated_maximum_outgoing_streams, 12);
}
2734
TEST(DcSctpSocketTest, ResetStreamsDeferred) {
  // Z receives a stream reset request before the DATA it covers. Z must defer
  // the reset, still deliver the earlier messages, and complete the reset once
  // A retries it.

  // Guaranteed to be fragmented into two fragments.
  constexpr size_t kTwoFragmentsSize = DcSctpOptions::kMaxSafeMTUSize + 100;

  SocketUnderTest a("A");
  SocketUnderTest z("Z");

  ConnectSockets(a, z);

  a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
                              std::vector<uint8_t>(kTwoFragmentsSize)),
                {});
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(54),
                              std::vector<uint8_t>(kSmallMessageSize)),
                {});

  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));

  // Three DATA packets (two fragments of SSN 0, one of SSN 1) followed by the
  // RE-CONFIG request.
  auto data1 = a.cb.ConsumeSentPacket();
  auto data2 = a.cb.ConsumeSentPacket();
  auto data3 = a.cb.ConsumeSentPacket();
  auto reconfig = a.cb.ConsumeSentPacket();

  EXPECT_THAT(data1, HasDataChunkWithSsn(SSN(0)));
  EXPECT_THAT(data2, HasDataChunkWithSsn(SSN(0)));
  EXPECT_THAT(data3, HasDataChunkWithSsn(SSN(1)));
  EXPECT_THAT(reconfig, HasReconfigWithStreams(ElementsAre(StreamID(1))));

  // Receive them slightly out of order to make stream resetting deferred.
  z.socket.ReceivePacket(reconfig);

  z.socket.ReceivePacket(data1);
  z.socket.ReceivePacket(data2);
  z.socket.ReceivePacket(data3);

  // Both messages sent before the reset are still delivered.
  absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg1.has_value());
  EXPECT_EQ(msg1->stream_id(), StreamID(1));
  EXPECT_EQ(msg1->ppid(), PPID(53));
  EXPECT_EQ(msg1->payload().size(), kTwoFragmentsSize);

  absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg2.has_value());
  EXPECT_EQ(msg2->stream_id(), StreamID(1));
  EXPECT_EQ(msg2->ppid(), PPID(54));
  EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);

  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
  ExchangeMessages(a, z);

  // Z sent "in progress", which will make A buffer packets until it's sure
  // that the reconfiguration has been applied. A will retry - wait for that.
  AdvanceTime(a, z, a.options.rto_initial);

  auto reconfig2 = a.cb.ConsumeSentPacket();
  EXPECT_THAT(reconfig2, HasReconfigWithStreams(ElementsAre(StreamID(1))));
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
  z.socket.ReceivePacket(reconfig2);

  auto reconfig3 = z.cb.ConsumeSentPacket();
  EXPECT_THAT(reconfig3,
              HasReconfigWithResponse(
                  ReconfigurationResponseParameter::Result::kSuccessPerformed));
  a.socket.ReceivePacket(reconfig3);

  // Send a new message after the stream has been reset.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(55),
                              std::vector<uint8_t>(kSmallMessageSize)),
                {});
  ExchangeMessages(a, z);

  absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
  ASSERT_TRUE(msg3.has_value());
  EXPECT_EQ(msg3->stream_id(), StreamID(1));
  EXPECT_EQ(msg3->ppid(), PPID(55));
  EXPECT_EQ(msg3->payload().size(), kSmallMessageSize);
}
2817
TEST(DcSctpSocketTest, ResetStreamsWithPausedSenderResumesWhenPerformed) {
  SocketUnderTest a("A");
  SocketUnderTest z("Z");
  ConnectSockets(a, z);

  // One message sent before the reset is requested...
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(51),
                              std::vector<uint8_t>(kSmallMessageSize)),
                {});
  a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
  // ...and one sent while the reset is outstanding. This one will be queued,
  // as the stream has an outstanding reset operation.
  a.socket.Send(DcSctpMessage(StreamID(1), PPID(52),
                              std::vector<uint8_t>(kSmallMessageSize)),
                {});

  EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
  EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
  ExchangeMessages(a, z);

  // Once the reset has been performed, both messages arrive in send order.
  for (PPID expected_ppid : {PPID(51), PPID(52)}) {
    absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
    ASSERT_TRUE(msg.has_value());
    EXPECT_EQ(msg->stream_id(), StreamID(1));
    EXPECT_EQ(msg->ppid(), expected_ppid);
    EXPECT_EQ(msg->payload().size(), kSmallMessageSize);
  }
}
2851
2852 } // namespace
2853 } // namespace dcsctp
2854