xref: /aosp_15_r20/external/webrtc/net/dcsctp/socket/dcsctp_socket_test.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "net/dcsctp/socket/dcsctp_socket.h"
11 
12 #include <cstdint>
13 #include <deque>
14 #include <memory>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "absl/flags/flag.h"
20 #include "absl/memory/memory.h"
21 #include "absl/strings/string_view.h"
22 #include "absl/types/optional.h"
23 #include "api/array_view.h"
24 #include "net/dcsctp/common/handover_testing.h"
25 #include "net/dcsctp/packet/chunk/chunk.h"
26 #include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
27 #include "net/dcsctp/packet/chunk/data_chunk.h"
28 #include "net/dcsctp/packet/chunk/data_common.h"
29 #include "net/dcsctp/packet/chunk/error_chunk.h"
30 #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
31 #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
32 #include "net/dcsctp/packet/chunk/idata_chunk.h"
33 #include "net/dcsctp/packet/chunk/init_chunk.h"
34 #include "net/dcsctp/packet/chunk/sack_chunk.h"
35 #include "net/dcsctp/packet/chunk/shutdown_chunk.h"
36 #include "net/dcsctp/packet/error_cause/error_cause.h"
37 #include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h"
38 #include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
39 #include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
40 #include "net/dcsctp/packet/parameter/parameter.h"
41 #include "net/dcsctp/packet/sctp_packet.h"
42 #include "net/dcsctp/packet/tlv_trait.h"
43 #include "net/dcsctp/public/dcsctp_message.h"
44 #include "net/dcsctp/public/dcsctp_options.h"
45 #include "net/dcsctp/public/dcsctp_socket.h"
46 #include "net/dcsctp/public/text_pcap_packet_observer.h"
47 #include "net/dcsctp/public/types.h"
48 #include "net/dcsctp/rx/reassembly_queue.h"
49 #include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
50 #include "net/dcsctp/testing/testing_macros.h"
51 #include "rtc_base/gunit.h"
52 #include "test/gmock.h"
53 
54 ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture.");
55 
56 namespace dcsctp {
57 namespace {
58 using ::testing::_;
59 using ::testing::AllOf;
60 using ::testing::ElementsAre;
61 using ::testing::HasSubstr;
62 using ::testing::IsEmpty;
63 using ::testing::SizeIs;
64 using ::testing::UnorderedElementsAre;
65 
66 constexpr SendOptions kSendOptions;
67 constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20;
68 constexpr size_t kSmallMessageSize = 10;
69 constexpr int kMaxBurstPackets = 4;
70 
71 MATCHER_P(HasDataChunkWithStreamId, stream_id, "") {
72   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
73   if (!packet.has_value()) {
74     *result_listener << "data didn't parse as an SctpPacket";
75     return false;
76   }
77 
78   if (packet->descriptors()[0].type != DataChunk::kType) {
79     *result_listener << "the first chunk in the packet is not a data chunk";
80     return false;
81   }
82 
83   absl::optional<DataChunk> dc =
84       DataChunk::Parse(packet->descriptors()[0].data);
85   if (!dc.has_value()) {
86     *result_listener << "The first chunk didn't parse as a data chunk";
87     return false;
88   }
89 
90   if (dc->stream_id() != stream_id) {
91     *result_listener << "the stream_id is " << *dc->stream_id();
92     return false;
93   }
94 
95   return true;
96 }
97 
98 MATCHER_P(HasDataChunkWithPPID, ppid, "") {
99   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
100   if (!packet.has_value()) {
101     *result_listener << "data didn't parse as an SctpPacket";
102     return false;
103   }
104 
105   if (packet->descriptors()[0].type != DataChunk::kType) {
106     *result_listener << "the first chunk in the packet is not a data chunk";
107     return false;
108   }
109 
110   absl::optional<DataChunk> dc =
111       DataChunk::Parse(packet->descriptors()[0].data);
112   if (!dc.has_value()) {
113     *result_listener << "The first chunk didn't parse as a data chunk";
114     return false;
115   }
116 
117   if (dc->ppid() != ppid) {
118     *result_listener << "the ppid is " << *dc->ppid();
119     return false;
120   }
121 
122   return true;
123 }
124 
125 MATCHER_P(HasDataChunkWithSsn, ssn, "") {
126   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
127   if (!packet.has_value()) {
128     *result_listener << "data didn't parse as an SctpPacket";
129     return false;
130   }
131 
132   if (packet->descriptors()[0].type != DataChunk::kType) {
133     *result_listener << "the first chunk in the packet is not a data chunk";
134     return false;
135   }
136 
137   absl::optional<DataChunk> dc =
138       DataChunk::Parse(packet->descriptors()[0].data);
139   if (!dc.has_value()) {
140     *result_listener << "The first chunk didn't parse as a data chunk";
141     return false;
142   }
143 
144   if (dc->ssn() != ssn) {
145     *result_listener << "the ssn is " << *dc->ssn();
146     return false;
147   }
148 
149   return true;
150 }
151 
152 MATCHER_P(HasDataChunkWithMid, mid, "") {
153   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
154   if (!packet.has_value()) {
155     *result_listener << "data didn't parse as an SctpPacket";
156     return false;
157   }
158 
159   if (packet->descriptors()[0].type != IDataChunk::kType) {
160     *result_listener << "the first chunk in the packet is not an i-data chunk";
161     return false;
162   }
163 
164   absl::optional<IDataChunk> dc =
165       IDataChunk::Parse(packet->descriptors()[0].data);
166   if (!dc.has_value()) {
167     *result_listener << "The first chunk didn't parse as an i-data chunk";
168     return false;
169   }
170 
171   if (dc->message_id() != mid) {
172     *result_listener << "the mid is " << *dc->message_id();
173     return false;
174   }
175 
176   return true;
177 }
178 
179 MATCHER_P(HasSackWithCumAckTsn, tsn, "") {
180   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
181   if (!packet.has_value()) {
182     *result_listener << "data didn't parse as an SctpPacket";
183     return false;
184   }
185 
186   if (packet->descriptors()[0].type != SackChunk::kType) {
187     *result_listener << "the first chunk in the packet is not a data chunk";
188     return false;
189   }
190 
191   absl::optional<SackChunk> sc =
192       SackChunk::Parse(packet->descriptors()[0].data);
193   if (!sc.has_value()) {
194     *result_listener << "The first chunk didn't parse as a data chunk";
195     return false;
196   }
197 
198   if (sc->cumulative_tsn_ack() != tsn) {
199     *result_listener << "the cum_ack_tsn is " << *sc->cumulative_tsn_ack();
200     return false;
201   }
202 
203   return true;
204 }
205 
206 MATCHER(HasSackWithNoGapAckBlocks, "") {
207   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
208   if (!packet.has_value()) {
209     *result_listener << "data didn't parse as an SctpPacket";
210     return false;
211   }
212 
213   if (packet->descriptors()[0].type != SackChunk::kType) {
214     *result_listener << "the first chunk in the packet is not a data chunk";
215     return false;
216   }
217 
218   absl::optional<SackChunk> sc =
219       SackChunk::Parse(packet->descriptors()[0].data);
220   if (!sc.has_value()) {
221     *result_listener << "The first chunk didn't parse as a data chunk";
222     return false;
223   }
224 
225   if (!sc->gap_ack_blocks().empty()) {
226     *result_listener << "there are gap ack blocks";
227     return false;
228   }
229 
230   return true;
231 }
232 
233 MATCHER_P(HasReconfigWithStreams, streams_matcher, "") {
234   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
235   if (!packet.has_value()) {
236     *result_listener << "data didn't parse as an SctpPacket";
237     return false;
238   }
239 
240   if (packet->descriptors()[0].type != ReConfigChunk::kType) {
241     *result_listener << "the first chunk in the packet is not a data chunk";
242     return false;
243   }
244 
245   absl::optional<ReConfigChunk> reconfig =
246       ReConfigChunk::Parse(packet->descriptors()[0].data);
247   if (!reconfig.has_value()) {
248     *result_listener << "The first chunk didn't parse as a data chunk";
249     return false;
250   }
251 
252   const Parameters& parameters = reconfig->parameters();
253   if (parameters.descriptors().size() != 1 ||
254       parameters.descriptors()[0].type !=
255           OutgoingSSNResetRequestParameter::kType) {
256     *result_listener << "Expected the reconfig chunk to have an outgoing SSN "
257                         "reset request parameter";
258     return false;
259   }
260 
261   absl::optional<OutgoingSSNResetRequestParameter> p =
262       OutgoingSSNResetRequestParameter::Parse(parameters.descriptors()[0].data);
263   testing::Matcher<rtc::ArrayView<const StreamID>> matcher = streams_matcher;
264   if (!matcher.MatchAndExplain(p->stream_ids(), result_listener)) {
265     return false;
266   }
267 
268   return true;
269 }
270 
271 MATCHER_P(HasReconfigWithResponse, result, "") {
272   absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
273   if (!packet.has_value()) {
274     *result_listener << "data didn't parse as an SctpPacket";
275     return false;
276   }
277 
278   if (packet->descriptors()[0].type != ReConfigChunk::kType) {
279     *result_listener << "the first chunk in the packet is not a reconfig chunk";
280     return false;
281   }
282 
283   absl::optional<ReConfigChunk> reconfig =
284       ReConfigChunk::Parse(packet->descriptors()[0].data);
285   if (!reconfig.has_value()) {
286     *result_listener << "The first chunk didn't parse as a reconfig chunk";
287     return false;
288   }
289 
290   const Parameters& parameters = reconfig->parameters();
291   if (parameters.descriptors().size() != 1 ||
292       parameters.descriptors()[0].type !=
293           ReconfigurationResponseParameter::kType) {
294     *result_listener << "Expected the reconfig chunk to have a "
295                         "ReconfigurationResponse Parameter";
296     return false;
297   }
298 
299   absl::optional<ReconfigurationResponseParameter> p =
300       ReconfigurationResponseParameter::Parse(parameters.descriptors()[0].data);
301   if (p->result() != result) {
302     *result_listener << "ReconfigurationResponse Parameter doesn't contain the "
303                         "expected result";
304     return false;
305   }
306 
307   return true;
308 }
309 
AddTo(TSN tsn,int delta)310 TSN AddTo(TSN tsn, int delta) {
311   return TSN(*tsn + delta);
312 }
313 
FixupOptions(DcSctpOptions options={})314 DcSctpOptions FixupOptions(DcSctpOptions options = {}) {
315   DcSctpOptions fixup = options;
316   // To make the interval more predictable in tests.
317   fixup.heartbeat_interval_include_rtt = false;
318   fixup.max_burst = kMaxBurstPackets;
319   return fixup;
320 }
321 
GetPacketObserver(absl::string_view name)322 std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) {
323   if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) {
324     return std::make_unique<TextPcapPacketObserver>(name);
325   }
326   return nullptr;
327 }
328 
329 struct SocketUnderTest {
SocketUnderTestdcsctp::__anon2dc231170111::SocketUnderTest330   explicit SocketUnderTest(absl::string_view name,
331                            const DcSctpOptions& opts = {})
332       : options(FixupOptions(opts)),
333         cb(name),
334         socket(name, cb, GetPacketObserver(name), options) {}
335 
336   const DcSctpOptions options;
337   testing::NiceMock<MockDcSctpSocketCallbacks> cb;
338   DcSctpSocket socket;
339 };
340 
ExchangeMessages(SocketUnderTest & a,SocketUnderTest & z)341 void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) {
342   bool delivered_packet = false;
343   do {
344     delivered_packet = false;
345     std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
346     if (!packet_from_a.empty()) {
347       delivered_packet = true;
348       z.socket.ReceivePacket(std::move(packet_from_a));
349     }
350     std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket();
351     if (!packet_from_z.empty()) {
352       delivered_packet = true;
353       a.socket.ReceivePacket(std::move(packet_from_z));
354     }
355   } while (delivered_packet);
356 }
357 
RunTimers(SocketUnderTest & s)358 void RunTimers(SocketUnderTest& s) {
359   for (;;) {
360     absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout();
361     if (!timeout_id.has_value()) {
362       break;
363     }
364     s.socket.HandleTimeout(*timeout_id);
365   }
366 }
367 
AdvanceTime(SocketUnderTest & a,SocketUnderTest & z,DurationMs duration)368 void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) {
369   a.cb.AdvanceTime(duration);
370   z.cb.AdvanceTime(duration);
371 
372   RunTimers(a);
373   RunTimers(z);
374 }
375 
376 // Calls Connect() on `sock_a_` and make the connection established.
ConnectSockets(SocketUnderTest & a,SocketUnderTest & z)377 void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) {
378   EXPECT_CALL(a.cb, OnConnected).Times(1);
379   EXPECT_CALL(z.cb, OnConnected).Times(1);
380 
381   a.socket.Connect();
382   // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK
383   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
384   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
385   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
386   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
387 
388   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
389   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
390 }
391 
HandoverSocket(std::unique_ptr<SocketUnderTest> sut)392 std::unique_ptr<SocketUnderTest> HandoverSocket(
393     std::unique_ptr<SocketUnderTest> sut) {
394   EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus());
395 
396   bool is_closed = sut->socket.state() == SocketState::kClosed;
397   if (!is_closed) {
398     EXPECT_CALL(sut->cb, OnClosed).Times(1);
399   }
400   absl::optional<DcSctpSocketHandoverState> handover_state =
401       sut->socket.GetHandoverStateAndClose();
402   EXPECT_TRUE(handover_state.has_value());
403   g_handover_state_transformer_for_test(&*handover_state);
404 
405   auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options);
406   if (!is_closed) {
407     EXPECT_CALL(handover_socket->cb, OnConnected).Times(1);
408   }
409   handover_socket->socket.RestoreFromState(*handover_state);
410   return handover_socket;
411 }
412 
GetReceivedMessagePpids(SocketUnderTest & z)413 std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) {
414   std::vector<uint32_t> ppids;
415   for (;;) {
416     absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
417     if (!msg.has_value()) {
418       break;
419     }
420     ppids.push_back(*msg->ppid());
421   }
422   return ppids;
423 }
424 
425 // Test parameter that controls whether to perform handovers during the test. A
426 // test can have multiple points where it conditionally hands over socket Z.
427 // Either socket Z will be handed over at all those points or handed over never.
428 enum class HandoverMode {
429   kNoHandover,
430   kPerformHandovers,
431 };
432 
433 class DcSctpSocketParametrizedTest
434     : public ::testing::Test,
435       public ::testing::WithParamInterface<HandoverMode> {
436  protected:
437   // Trigger handover for `sut` depending on the current test param.
MaybeHandoverSocket(std::unique_ptr<SocketUnderTest> sut)438   std::unique_ptr<SocketUnderTest> MaybeHandoverSocket(
439       std::unique_ptr<SocketUnderTest> sut) {
440     if (GetParam() == HandoverMode::kPerformHandovers) {
441       return HandoverSocket(std::move(sut));
442     }
443     return sut;
444   }
445 
446   // Trigger handover for socket Z depending on the current test param.
447   // Then checks message passing to verify the handed over socket is functional.
MaybeHandoverSocketAndSendMessage(SocketUnderTest & a,std::unique_ptr<SocketUnderTest> z)448   void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a,
449                                          std::unique_ptr<SocketUnderTest> z) {
450     if (GetParam() == HandoverMode::kPerformHandovers) {
451       z = HandoverSocket(std::move(z));
452     }
453 
454     ExchangeMessages(a, *z);
455     a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
456     ExchangeMessages(a, *z);
457 
458     absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
459     ASSERT_TRUE(msg.has_value());
460     EXPECT_EQ(msg->stream_id(), StreamID(1));
461   }
462 };
463 
464 INSTANTIATE_TEST_SUITE_P(Handovers,
465                          DcSctpSocketParametrizedTest,
466                          testing::Values(HandoverMode::kNoHandover,
467                                          HandoverMode::kPerformHandovers),
__anon2dc231170202(const auto& test_info) 468                          [](const auto& test_info) {
469                            return test_info.param ==
470                                           HandoverMode::kPerformHandovers
471                                       ? "WithHandovers"
472                                       : "NoHandover";
473                          });
474 
TEST(DcSctpSocketTest,EstablishConnection)475 TEST(DcSctpSocketTest, EstablishConnection) {
476   SocketUnderTest a("A");
477   SocketUnderTest z("Z");
478 
479   EXPECT_CALL(a.cb, OnConnected).Times(1);
480   EXPECT_CALL(z.cb, OnConnected).Times(1);
481   EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
482   EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
483 
484   a.socket.Connect();
485   // Z reads INIT, produces INIT_ACK
486   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
487   // A reads INIT_ACK, produces COOKIE_ECHO
488   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
489   // Z reads COOKIE_ECHO, produces COOKIE_ACK
490   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
491   // A reads COOKIE_ACK.
492   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
493 
494   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
495   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
496 }
497 
TEST(DcSctpSocketTest,EstablishConnectionWithSetupCollision)498 TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) {
499   SocketUnderTest a("A");
500   SocketUnderTest z("Z");
501 
502   EXPECT_CALL(a.cb, OnConnected).Times(1);
503   EXPECT_CALL(z.cb, OnConnected).Times(1);
504   EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
505   EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
506   a.socket.Connect();
507   z.socket.Connect();
508 
509   ExchangeMessages(a, z);
510 
511   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
512   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
513 }
514 
TEST(DcSctpSocketTest,ShuttingDownWhileEstablishingConnection)515 TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) {
516   SocketUnderTest a("A");
517   SocketUnderTest z("Z");
518 
519   EXPECT_CALL(a.cb, OnConnected).Times(0);
520   EXPECT_CALL(z.cb, OnConnected).Times(1);
521   a.socket.Connect();
522 
523   // Z reads INIT, produces INIT_ACK
524   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
525   // A reads INIT_ACK, produces COOKIE_ECHO
526   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
527   // Z reads COOKIE_ECHO, produces COOKIE_ACK
528   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
529   // Drop COOKIE_ACK, just to more easily verify shutdown protocol.
530   z.cb.ConsumeSentPacket();
531 
532   // As Socket A has received INIT_ACK, it has a TCB and is connected, while
533   // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has
534   // timers running at this point.
535   EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
536   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
537 
538   // Socket A is now shut down, which should make it stop those timers.
539   a.socket.Shutdown();
540 
541   EXPECT_CALL(a.cb, OnClosed).Times(1);
542   EXPECT_CALL(z.cb, OnClosed).Times(1);
543 
544   // Z reads SHUTDOWN, produces SHUTDOWN_ACK
545   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
546   // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
547   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
548   // Z reads SHUTDOWN_COMPLETE.
549   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
550 
551   EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
552   EXPECT_TRUE(z.cb.ConsumeSentPacket().empty());
553 
554   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
555   EXPECT_EQ(z.socket.state(), SocketState::kClosed);
556 }
557 
TEST(DcSctpSocketTest,EstablishSimultaneousConnection)558 TEST(DcSctpSocketTest, EstablishSimultaneousConnection) {
559   SocketUnderTest a("A");
560   SocketUnderTest z("Z");
561 
562   EXPECT_CALL(a.cb, OnConnected).Times(1);
563   EXPECT_CALL(z.cb, OnConnected).Times(1);
564   EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
565   EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
566   a.socket.Connect();
567 
568   // INIT isn't received by Z, as it wasn't ready yet.
569   a.cb.ConsumeSentPacket();
570 
571   z.socket.Connect();
572 
573   // A reads INIT, produces INIT_ACK
574   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
575 
576   // Z reads INIT_ACK, sends COOKIE_ECHO
577   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
578 
579   // A reads COOKIE_ECHO - establishes connection.
580   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
581 
582   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
583 
584   // Proceed with the remaining packets.
585   ExchangeMessages(a, z);
586 
587   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
588   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
589 }
590 
TEST(DcSctpSocketTest,EstablishConnectionLostCookieAck)591 TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) {
592   SocketUnderTest a("A");
593   SocketUnderTest z("Z");
594 
595   EXPECT_CALL(a.cb, OnConnected).Times(1);
596   EXPECT_CALL(z.cb, OnConnected).Times(1);
597   EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
598   EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
599 
600   a.socket.Connect();
601   // Z reads INIT, produces INIT_ACK
602   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
603   // A reads INIT_ACK, produces COOKIE_ECHO
604   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
605   // Z reads COOKIE_ECHO, produces COOKIE_ACK
606   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
607   // COOKIE_ACK is lost.
608   z.cb.ConsumeSentPacket();
609 
610   EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
611   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
612 
613   // This will make A re-send the COOKIE_ECHO
614   AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout));
615 
616   // Z reads COOKIE_ECHO, produces COOKIE_ACK
617   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
618   // A reads COOKIE_ACK.
619   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
620 
621   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
622   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
623 }
624 
TEST(DcSctpSocketTest,ResendInitAndEstablishConnection)625 TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) {
626   SocketUnderTest a("A");
627   SocketUnderTest z("Z");
628 
629   a.socket.Connect();
630   // INIT is never received by Z.
631   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
632                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
633   EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType);
634 
635   AdvanceTime(a, z, a.options.t1_init_timeout);
636 
637   // Z reads INIT, produces INIT_ACK
638   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
639   // A reads INIT_ACK, produces COOKIE_ECHO
640   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
641   // Z reads COOKIE_ECHO, produces COOKIE_ACK
642   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
643   // A reads COOKIE_ACK.
644   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
645 
646   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
647   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
648 }
649 
TEST(DcSctpSocketTest,ResendingInitTooManyTimesAborts)650 TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) {
651   SocketUnderTest a("A");
652   SocketUnderTest z("Z");
653 
654   a.socket.Connect();
655 
656   // INIT is never received by Z.
657   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
658                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
659   EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType);
660 
661   for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
662     AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i));
663 
664     // INIT is resent
665     ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet,
666                                 SctpPacket::Parse(a.cb.ConsumeSentPacket()));
667     EXPECT_EQ(resent_init_packet.descriptors()[0].type, InitChunk::kType);
668   }
669 
670   // Another timeout, after the max init retransmits.
671   EXPECT_CALL(a.cb, OnAborted).Times(1);
672   AdvanceTime(
673       a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits));
674 
675   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
676 }
677 
TEST(DcSctpSocketTest,ResendCookieEchoAndEstablishConnection)678 TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) {
679   SocketUnderTest a("A");
680   SocketUnderTest z("Z");
681 
682   a.socket.Connect();
683 
684   // Z reads INIT, produces INIT_ACK
685   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
686   // A reads INIT_ACK, produces COOKIE_ECHO
687   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
688 
689   // COOKIE_ECHO is never received by Z.
690   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
691                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
692   EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType);
693 
694   AdvanceTime(a, z, a.options.t1_init_timeout);
695 
696   // Z reads COOKIE_ECHO, produces COOKIE_ACK
697   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
698   // A reads COOKIE_ACK.
699   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
700 
701   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
702   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
703 }
704 
TEST(DcSctpSocketTest,ResendingCookieEchoTooManyTimesAborts)705 TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) {
706   SocketUnderTest a("A");
707   SocketUnderTest z("Z");
708 
709   a.socket.Connect();
710 
711   // Z reads INIT, produces INIT_ACK
712   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
713   // A reads INIT_ACK, produces COOKIE_ECHO
714   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
715 
716   // COOKIE_ECHO is never received by Z.
717   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
718                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
719   EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType);
720 
721   for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
722     AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i));
723 
724     // COOKIE_ECHO is resent
725     ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet,
726                                 SctpPacket::Parse(a.cb.ConsumeSentPacket()));
727     EXPECT_EQ(resent_init_packet.descriptors()[0].type, CookieEchoChunk::kType);
728   }
729 
730   // Another timeout, after the max init retransmits.
731   EXPECT_CALL(a.cb, OnAborted).Times(1);
732   AdvanceTime(
733       a, z,
734       a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits));
735 
736   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
737 }
738 
TEST(DcSctpSocketTest,DoesntSendMorePacketsUntilCookieAckHasBeenReceived)739 TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
740   SocketUnderTest a("A");
741   SocketUnderTest z("Z");
742 
743   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
744                               std::vector<uint8_t>(kLargeMessageSize)),
745                 kSendOptions);
746   a.socket.Connect();
747 
748   // Z reads INIT, produces INIT_ACK
749   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
750   // A reads INIT_ACK, produces COOKIE_ECHO
751   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
752 
753   // COOKIE_ECHO is never received by Z.
754   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1,
755                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
756   EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2));
757   EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType);
758   EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType);
759 
760   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
761 
762   // There are DATA chunks in the sent packet (that was lost), which means that
763   // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it
764   // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires,
765   // nothing should be retransmitted.
766   ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout);
767   AdvanceTime(a, z, a.options.rto_initial);
768   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
769 
770   // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present.
771   AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial);
772 
773   // And this COOKIE-ECHO and DATA is also lost - never received by Z.
774   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2,
775                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
776   EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2));
777   EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType);
778   EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType);
779 
780   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
781 
782   // COOKIE_ECHO has exponential backoff.
783   AdvanceTime(a, z, a.options.t1_cookie_timeout * 2);
784 
785   // Z reads COOKIE_ECHO, produces COOKIE_ACK
786   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
787   // A reads COOKIE_ACK.
788   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
789 
790   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
791   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
792 
793   ExchangeMessages(a, z);
794   EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(),
795               SizeIs(kLargeMessageSize));
796 }
797 
TEST_P(DcSctpSocketParametrizedTest,ShutdownConnection)798 TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
799   SocketUnderTest a("A");
800   auto z = std::make_unique<SocketUnderTest>("Z");
801 
802   ConnectSockets(a, *z);
803   z = MaybeHandoverSocket(std::move(z));
804 
805   RTC_LOG(LS_INFO) << "Shutting down";
806 
807   EXPECT_CALL(z->cb, OnClosed).Times(1);
808   a.socket.Shutdown();
809   // Z reads SHUTDOWN, produces SHUTDOWN_ACK
810   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
811   // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
812   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
813   // Z reads SHUTDOWN_COMPLETE.
814   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
815 
816   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
817   EXPECT_EQ(z->socket.state(), SocketState::kClosed);
818 
819   z = MaybeHandoverSocket(std::move(z));
820   EXPECT_EQ(z->socket.state(), SocketState::kClosed);
821 }
822 
TEST(DcSctpSocketTest,ShutdownTimerExpiresTooManyTimeClosesConnection)823 TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
824   SocketUnderTest a("A");
825   SocketUnderTest z("Z");
826 
827   ConnectSockets(a, z);
828 
829   a.socket.Shutdown();
830   // Drop first SHUTDOWN packet.
831   a.cb.ConsumeSentPacket();
832 
833   EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown);
834 
835   for (int i = 0; i < *a.options.max_retransmissions; ++i) {
836     AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i)));
837 
838     // Dropping every shutdown chunk.
839     ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
840                                 SctpPacket::Parse(a.cb.ConsumeSentPacket()));
841     EXPECT_EQ(packet.descriptors()[0].type, ShutdownChunk::kType);
842     EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
843   }
844   // The last expiry, makes it abort the connection.
845   EXPECT_CALL(a.cb, OnAborted).Times(1);
846   AdvanceTime(a, z,
847               a.options.rto_initial * (1 << *a.options.max_retransmissions));
848 
849   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
850   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
851                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
852   EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType);
853   EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
854 }
855 
TEST(DcSctpSocketTest,EstablishConnectionWhileSendingData)856 TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) {
857   SocketUnderTest a("A");
858   SocketUnderTest z("Z");
859 
860   a.socket.Connect();
861 
862   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
863 
864   // Z reads INIT, produces INIT_ACK
865   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
866   // // A reads INIT_ACK, produces COOKIE_ECHO
867   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
868   // // Z reads COOKIE_ECHO, produces COOKIE_ACK
869   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
870   // // A reads COOKIE_ACK.
871   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
872 
873   EXPECT_EQ(a.socket.state(), SocketState::kConnected);
874   EXPECT_EQ(z.socket.state(), SocketState::kConnected);
875 
876   absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
877   ASSERT_TRUE(msg.has_value());
878   EXPECT_EQ(msg->stream_id(), StreamID(1));
879 }
880 
TEST(DcSctpSocketTest,SendMessageAfterEstablished)881 TEST(DcSctpSocketTest, SendMessageAfterEstablished) {
882   SocketUnderTest a("A");
883   SocketUnderTest z("Z");
884 
885   ConnectSockets(a, z);
886 
887   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
888   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
889 
890   absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
891   ASSERT_TRUE(msg.has_value());
892   EXPECT_EQ(msg->stream_id(), StreamID(1));
893 }
894 
TEST_P(DcSctpSocketParametrizedTest,TimeoutResendsPacket)895 TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
896   SocketUnderTest a("A");
897   auto z = std::make_unique<SocketUnderTest>("Z");
898 
899   ConnectSockets(a, *z);
900   z = MaybeHandoverSocket(std::move(z));
901 
902   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
903   a.cb.ConsumeSentPacket();
904 
905   RTC_LOG(LS_INFO) << "Advancing time";
906   AdvanceTime(a, *z, a.options.rto_initial);
907 
908   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
909 
910   absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
911   ASSERT_TRUE(msg.has_value());
912   EXPECT_EQ(msg->stream_id(), StreamID(1));
913 
914   MaybeHandoverSocketAndSendMessage(a, std::move(z));
915 }
916 
TEST_P(DcSctpSocketParametrizedTest,SendALotOfBytesMissedSecondPacket)917 TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
918   SocketUnderTest a("A");
919   auto z = std::make_unique<SocketUnderTest>("Z");
920 
921   ConnectSockets(a, *z);
922   z = MaybeHandoverSocket(std::move(z));
923 
924   std::vector<uint8_t> payload(kLargeMessageSize);
925   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
926 
927   // First DATA
928   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
929   // Second DATA (lost)
930   a.cb.ConsumeSentPacket();
931 
932   // Retransmit and handle the rest
933   ExchangeMessages(a, *z);
934 
935   absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
936   ASSERT_TRUE(msg.has_value());
937   EXPECT_EQ(msg->stream_id(), StreamID(1));
938   EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
939 
940   MaybeHandoverSocketAndSendMessage(a, std::move(z));
941 }
942 
TEST_P(DcSctpSocketParametrizedTest,SendingHeartbeatAnswersWithAck)943 TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
944   SocketUnderTest a("A");
945   auto z = std::make_unique<SocketUnderTest>("Z");
946 
947   ConnectSockets(a, *z);
948   z = MaybeHandoverSocket(std::move(z));
949 
950   // Inject a HEARTBEAT chunk
951   SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
952   uint8_t info[] = {1, 2, 3, 4};
953   Parameters::Builder params_builder;
954   params_builder.Add(HeartbeatInfoParameter(info));
955   b.Add(HeartbeatRequestChunk(params_builder.Build()));
956   a.socket.ReceivePacket(b.Build());
957 
958   // HEARTBEAT_ACK is sent as a reply. Capture it.
959   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket ack_packet,
960                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
961   ASSERT_THAT(ack_packet.descriptors(), SizeIs(1));
962   ASSERT_HAS_VALUE_AND_ASSIGN(
963       HeartbeatAckChunk ack,
964       HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data));
965   ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info());
966   EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4));
967 
968   MaybeHandoverSocketAndSendMessage(a, std::move(z));
969 }
970 
TEST_P(DcSctpSocketParametrizedTest,ExpectHeartbeatToBeSent)971 TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
972   SocketUnderTest a("A");
973   auto z = std::make_unique<SocketUnderTest>("Z");
974 
975   ConnectSockets(a, *z);
976   z = MaybeHandoverSocket(std::move(z));
977 
978   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
979 
980   AdvanceTime(a, *z, a.options.heartbeat_interval);
981 
982   std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
983   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
984                               SctpPacket::Parse(hb_packet_raw));
985   ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
986   ASSERT_HAS_VALUE_AND_ASSIGN(
987       HeartbeatRequestChunk hb,
988       HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
989   ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, hb.info());
990 
991   // The info is a single 64-bit number.
992   EXPECT_THAT(hb.info()->info(), SizeIs(8));
993 
994   // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
995   z->socket.ReceivePacket(hb_packet_raw);
996   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
997 
998   MaybeHandoverSocketAndSendMessage(a, std::move(z));
999 }
1000 
TEST_P(DcSctpSocketParametrizedTest,CloseConnectionAfterTooManyLostHeartbeats)1001 TEST_P(DcSctpSocketParametrizedTest,
1002        CloseConnectionAfterTooManyLostHeartbeats) {
1003   SocketUnderTest a("A");
1004   auto z = std::make_unique<SocketUnderTest>("Z");
1005 
1006   ConnectSockets(a, *z);
1007   z = MaybeHandoverSocket(std::move(z));
1008 
1009   EXPECT_CALL(z->cb, OnClosed).Times(1);
1010   EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
1011   // Force-close socket Z so that it doesn't interfere from now on.
1012   z->socket.Close();
1013 
1014   DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
1015 
1016   for (int i = 0; i < *a.options.max_retransmissions; ++i) {
1017     RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
1018     AdvanceTime(a, *z, time_to_next_hearbeat);
1019 
1020     // Dropping every heartbeat.
1021     ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
1022                                 SctpPacket::Parse(a.cb.ConsumeSentPacket()));
1023     EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
1024 
1025     RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
1026     AdvanceTime(a, *z, DurationMs(1000));
1027 
1028     time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
1029   }
1030 
1031   RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
1032   AdvanceTime(a, *z, time_to_next_hearbeat);
1033 
1034   // Last heartbeat
1035   EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty()));
1036 
1037   EXPECT_CALL(a.cb, OnAborted).Times(1);
1038   // Should suffice as exceeding RTO
1039   AdvanceTime(a, *z, DurationMs(1000));
1040 
1041   z = MaybeHandoverSocket(std::move(z));
1042 }
1043 
TEST_P(DcSctpSocketParametrizedTest,RecoversAfterASuccessfulAck)1044 TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
1045   SocketUnderTest a("A");
1046   auto z = std::make_unique<SocketUnderTest>("Z");
1047 
1048   ConnectSockets(a, *z);
1049   z = MaybeHandoverSocket(std::move(z));
1050 
1051   EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
1052   EXPECT_CALL(z->cb, OnClosed).Times(1);
1053   // Force-close socket Z so that it doesn't interfere from now on.
1054   z->socket.Close();
1055 
1056   DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
1057 
1058   for (int i = 0; i < *a.options.max_retransmissions; ++i) {
1059     AdvanceTime(a, *z, time_to_next_hearbeat);
1060 
1061     // Dropping every heartbeat.
1062     a.cb.ConsumeSentPacket();
1063 
1064     RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
1065     AdvanceTime(a, *z, DurationMs(1000));
1066 
1067     time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
1068   }
1069 
1070   RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it";
1071   AdvanceTime(a, *z, time_to_next_hearbeat);
1072 
1073   std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
1074   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
1075                               SctpPacket::Parse(hb_packet_raw));
1076   ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
1077   ASSERT_HAS_VALUE_AND_ASSIGN(
1078       HeartbeatRequestChunk hb,
1079       HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
1080 
1081   SctpPacket::Builder b(a.socket.verification_tag(), a.options);
1082   b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters()));
1083   a.socket.ReceivePacket(b.Build());
1084 
1085   // Should suffice as exceeding RTO - which will not fire.
1086   EXPECT_CALL(a.cb, OnAborted).Times(0);
1087   AdvanceTime(a, *z, DurationMs(1000));
1088 
1089   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
1090 
1091   // Verify that we get new heartbeats again.
1092   RTC_LOG(LS_INFO) << "Expecting a new heartbeat";
1093   AdvanceTime(a, *z, time_to_next_hearbeat);
1094 
1095   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket another_packet,
1096                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
1097   EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
1098 }
1099 
TEST_P(DcSctpSocketParametrizedTest,ResetStream)1100 TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
1101   SocketUnderTest a("A");
1102   auto z = std::make_unique<SocketUnderTest>("Z");
1103 
1104   ConnectSockets(a, *z);
1105   z = MaybeHandoverSocket(std::move(z));
1106 
1107   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
1108   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1109 
1110   absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
1111   ASSERT_TRUE(msg.has_value());
1112   EXPECT_EQ(msg->stream_id(), StreamID(1));
1113 
1114   // Handle SACK
1115   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1116 
1117   // Reset the outgoing stream. This will directly send a RE-CONFIG.
1118   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
1119 
1120   // Receiving the packet will trigger a callback, indicating that A has
1121   // reset its stream. It will also send a RE-CONFIG with a response.
1122   EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1);
1123   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1124 
1125   // Receiving a response will trigger a callback. Streams are now reset.
1126   EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1);
1127   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1128 
1129   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1130 }
1131 
TEST_P(DcSctpSocketParametrizedTest,ResetStreamWillMakeChunksStartAtZeroSsn)1132 TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
1133   SocketUnderTest a("A");
1134   auto z = std::make_unique<SocketUnderTest>("Z");
1135 
1136   ConnectSockets(a, *z);
1137   z = MaybeHandoverSocket(std::move(z));
1138 
1139   std::vector<uint8_t> payload(a.options.mtu - 100);
1140 
1141   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1142   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1143 
1144   auto packet1 = a.cb.ConsumeSentPacket();
1145   EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0)));
1146   z->socket.ReceivePacket(packet1);
1147 
1148   auto packet2 = a.cb.ConsumeSentPacket();
1149   EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1)));
1150   z->socket.ReceivePacket(packet2);
1151 
1152   // Handle SACK
1153   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1154 
1155   absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
1156   ASSERT_TRUE(msg1.has_value());
1157   EXPECT_EQ(msg1->stream_id(), StreamID(1));
1158 
1159   absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
1160   ASSERT_TRUE(msg2.has_value());
1161   EXPECT_EQ(msg2->stream_id(), StreamID(1));
1162 
1163   // Reset the outgoing stream. This will directly send a RE-CONFIG.
1164   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
1165   // RE-CONFIG, req
1166   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1167   // RE-CONFIG, resp
1168   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1169 
1170   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1171 
1172   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1173 
1174   auto packet3 = a.cb.ConsumeSentPacket();
1175   EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0)));
1176   z->socket.ReceivePacket(packet3);
1177 
1178   auto packet4 = a.cb.ConsumeSentPacket();
1179   EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1)));
1180   z->socket.ReceivePacket(packet4);
1181 
1182   // Handle SACK
1183   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1184 
1185   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1186 }
1187 
TEST_P(DcSctpSocketParametrizedTest,ResetStreamWillOnlyResetTheRequestedStreams)1188 TEST_P(DcSctpSocketParametrizedTest,
1189        ResetStreamWillOnlyResetTheRequestedStreams) {
1190   SocketUnderTest a("A");
1191   auto z = std::make_unique<SocketUnderTest>("Z");
1192 
1193   ConnectSockets(a, *z);
1194   z = MaybeHandoverSocket(std::move(z));
1195 
1196   std::vector<uint8_t> payload(a.options.mtu - 100);
1197 
1198   // Send two ordered messages on SID 1
1199   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1200   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1201 
1202   auto packet1 = a.cb.ConsumeSentPacket();
1203   EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1)));
1204   EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0)));
1205   z->socket.ReceivePacket(packet1);
1206 
1207   auto packet2 = a.cb.ConsumeSentPacket();
1208   EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1)));
1209   EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1)));
1210   z->socket.ReceivePacket(packet2);
1211 
1212   // Handle SACK
1213   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1214 
1215   // Do the same, for SID 3
1216   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
1217   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
1218   auto packet3 = a.cb.ConsumeSentPacket();
1219   EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3)));
1220   EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0)));
1221   z->socket.ReceivePacket(packet3);
1222   auto packet4 = a.cb.ConsumeSentPacket();
1223   EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3)));
1224   EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1)));
1225   z->socket.ReceivePacket(packet4);
1226   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1227 
1228   // Receive all messages.
1229   absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
1230   ASSERT_TRUE(msg1.has_value());
1231   EXPECT_EQ(msg1->stream_id(), StreamID(1));
1232 
1233   absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
1234   ASSERT_TRUE(msg2.has_value());
1235   EXPECT_EQ(msg2->stream_id(), StreamID(1));
1236 
1237   absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
1238   ASSERT_TRUE(msg3.has_value());
1239   EXPECT_EQ(msg3->stream_id(), StreamID(3));
1240 
1241   absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage();
1242   ASSERT_TRUE(msg4.has_value());
1243   EXPECT_EQ(msg4->stream_id(), StreamID(3));
1244 
1245   // Reset SID 1. This will directly send a RE-CONFIG.
1246   a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
1247   // RE-CONFIG, req
1248   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1249   // RE-CONFIG, resp
1250   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1251 
1252   // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should.
1253   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
1254 
1255   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
1256 
1257   auto packet5 = a.cb.ConsumeSentPacket();
1258   EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1)));
1259   EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2)));  // Unchanged.
1260   z->socket.ReceivePacket(packet5);
1261 
1262   auto packet6 = a.cb.ConsumeSentPacket();
1263   EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3)));
1264   EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0)));  // Reset.
1265   z->socket.ReceivePacket(packet6);
1266 
1267   // Handle SACK
1268   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1269 
1270   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1271 }
1272 
TEST_P(DcSctpSocketParametrizedTest,OnePeerReconnects)1273 TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
1274   SocketUnderTest a("A");
1275   auto z = std::make_unique<SocketUnderTest>("Z");
1276 
1277   ConnectSockets(a, *z);
1278   z = MaybeHandoverSocket(std::move(z));
1279 
1280   EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1);
1281   // Let's be evil here - reconnect while a fragmented packet was about to be
1282   // sent. The receiving side should get it in full.
1283   std::vector<uint8_t> payload(kLargeMessageSize);
1284   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
1285 
1286   // First DATA
1287   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1288 
1289   // Create a new association, z2 - and don't use z anymore.
1290   SocketUnderTest z2("Z2");
1291   z2.socket.Connect();
1292 
1293   // Retransmit and handle the rest. As there will be some chunks in-flight that
1294   // have the wrong verification tag, those will yield errors.
1295   ExchangeMessages(a, z2);
1296 
1297   absl::optional<DcSctpMessage> msg = z2.cb.ConsumeReceivedMessage();
1298   ASSERT_TRUE(msg.has_value());
1299   EXPECT_EQ(msg->stream_id(), StreamID(1));
1300   EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
1301 }
1302 
TEST_P(DcSctpSocketParametrizedTest,SendMessageWithLimitedRtx)1303 TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
1304   SocketUnderTest a("A");
1305   auto z = std::make_unique<SocketUnderTest>("Z");
1306 
1307   ConnectSockets(a, *z);
1308   z = MaybeHandoverSocket(std::move(z));
1309 
1310   SendOptions send_options;
1311   send_options.max_retransmissions = 0;
1312   std::vector<uint8_t> payload(a.options.mtu - 100);
1313   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
1314   a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
1315   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
1316 
1317   // First DATA
1318   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1319   // Second DATA (lost)
1320   a.cb.ConsumeSentPacket();
1321   // Third DATA
1322   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1323 
1324   // Handle SACK for first DATA
1325   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1326 
1327   // Handle delayed SACK for third DATA
1328   AdvanceTime(a, *z, a.options.delayed_ack_max_timeout);
1329 
1330   // Handle SACK for second DATA
1331   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1332 
1333   // Now the missing data chunk will be marked as nacked, but it might still be
1334   // in-flight and the reported gap could be due to out-of-order delivery. So
1335   // the RetransmissionQueue will not mark it as "to be retransmitted" until
1336   // after the t3-rtx timer has expired.
1337   AdvanceTime(a, *z, a.options.rto_initial);
1338 
1339   // The chunk will be marked as retransmitted, and then as abandoned, which
1340   // will trigger a FORWARD-TSN to be sent.
1341 
1342   // FORWARD-TSN (third)
1343   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
1344 
1345   // Which will trigger a SACK
1346   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
1347 
1348   absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
1349   ASSERT_TRUE(msg1.has_value());
1350   EXPECT_EQ(msg1->ppid(), PPID(51));
1351 
1352   absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
1353   ASSERT_TRUE(msg2.has_value());
1354   EXPECT_EQ(msg2->ppid(), PPID(53));
1355 
1356   absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
1357   EXPECT_FALSE(msg3.has_value());
1358 
1359   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1360 }
1361 
TEST_P(DcSctpSocketParametrizedTest,SendManyFragmentedMessagesWithLimitedRtx)1362 TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
1363   SocketUnderTest a("A");
1364   auto z = std::make_unique<SocketUnderTest>("Z");
1365 
1366   ConnectSockets(a, *z);
1367   z = MaybeHandoverSocket(std::move(z));
1368 
1369   SendOptions send_options;
1370   send_options.unordered = IsUnordered(true);
1371   send_options.max_retransmissions = 0;
1372   std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */);
1373   // Sending first message
1374   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
1375   // Sending second message
1376   a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
1377   // Sending third message
1378   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
1379   // Sending fourth message
1380   a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options);
1381 
1382   // First DATA, first fragment
1383   std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
1384   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));
1385   z->socket.ReceivePacket(std::move(packet));
1386 
1387   // First DATA, second fragment (lost)
1388   packet = a.cb.ConsumeSentPacket();
1389   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));
1390 
1391   // Second DATA, first fragment
1392   packet = a.cb.ConsumeSentPacket();
1393   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
1394   z->socket.ReceivePacket(std::move(packet));
1395 
1396   // Second DATA, second fragment (lost)
1397   packet = a.cb.ConsumeSentPacket();
1398   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
1399   EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
1400 
1401   // Third DATA, first fragment
1402   packet = a.cb.ConsumeSentPacket();
1403   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
1404   EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
1405   z->socket.ReceivePacket(std::move(packet));
1406 
1407   // Third DATA, second fragment (lost)
1408   packet = a.cb.ConsumeSentPacket();
1409   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
1410   EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
1411 
1412   // Fourth DATA, first fragment
1413   packet = a.cb.ConsumeSentPacket();
1414   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
1415   EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
1416   z->socket.ReceivePacket(std::move(packet));
1417 
1418   // Fourth DATA, second fragment
1419   packet = a.cb.ConsumeSentPacket();
1420   EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
1421   EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
1422   z->socket.ReceivePacket(std::move(packet));
1423 
1424   ExchangeMessages(a, *z);
1425 
1426   // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs
1427   AdvanceTime(a, *z, a.options.rto_initial);
1428 
1429   ExchangeMessages(a, *z);
1430 
1431   absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
1432   ASSERT_TRUE(msg1.has_value());
1433   EXPECT_EQ(msg1->ppid(), PPID(54));
1434 
1435   ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
1436 
1437   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1438 }
1439 
1440 struct FakeChunkConfig : ChunkConfig {
1441   static constexpr int kType = 0x49;
1442   static constexpr size_t kHeaderSize = 4;
1443   static constexpr int kVariableLengthAlignment = 0;
1444 };
1445 
1446 class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> {
1447  public:
FakeChunk()1448   FakeChunk() {}
1449 
1450   FakeChunk(FakeChunk&& other) = default;
1451   FakeChunk& operator=(FakeChunk&& other) = default;
1452 
SerializeTo(std::vector<uint8_t> & out) const1453   void SerializeTo(std::vector<uint8_t>& out) const override {
1454     AllocateTLV(out);
1455   }
ToString() const1456   std::string ToString() const override { return "FAKE"; }
1457 };
1458 
TEST_P(DcSctpSocketParametrizedTest,ReceivingUnknownChunkRespondsWithError)1459 TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
1460   SocketUnderTest a("A");
1461   auto z = std::make_unique<SocketUnderTest>("Z");
1462 
1463   ConnectSockets(a, *z);
1464   z = MaybeHandoverSocket(std::move(z));
1465 
1466   // Inject a FAKE chunk
1467   SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
1468   b.Add(FakeChunk());
1469   a.socket.ReceivePacket(b.Build());
1470 
1471   // ERROR is sent as a reply. Capture it.
1472   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet,
1473                               SctpPacket::Parse(a.cb.ConsumeSentPacket()));
1474   ASSERT_THAT(reply_packet.descriptors(), SizeIs(1));
1475   ASSERT_HAS_VALUE_AND_ASSIGN(
1476       ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data));
1477   ASSERT_HAS_VALUE_AND_ASSIGN(
1478       UnrecognizedChunkTypeCause cause,
1479       error.error_causes().get<UnrecognizedChunkTypeCause>());
1480   EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
1481 
1482   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1483 }
1484 
TEST_P(DcSctpSocketParametrizedTest,ReceivingErrorChunkReportsAsCallback)1485 TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
1486   SocketUnderTest a("A");
1487   auto z = std::make_unique<SocketUnderTest>("Z");
1488 
1489   ConnectSockets(a, *z);
1490   z = MaybeHandoverSocket(std::move(z));
1491 
1492   // Inject a ERROR chunk
1493   SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
1494   b.Add(
1495       ErrorChunk(Parameters::Builder()
1496                      .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04}))
1497                      .Build()));
1498 
1499   EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported,
1500                             HasSubstr("Unrecognized Chunk Type")));
1501   a.socket.ReceivePacket(b.Build());
1502 
1503   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1504 }
1505 
TEST(DcSctpSocketTest,PassingHighWatermarkWillOnlyAcceptCumAckTsn)1506 TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
1507   SocketUnderTest a("A");
1508 
1509   constexpr size_t kReceiveWindowBufferSize = 2000;
1510   SocketUnderTest z(
1511       "Z", {.mtu = 3000,
1512             .max_receiver_window_buffer_size = kReceiveWindowBufferSize});
1513 
1514   EXPECT_CALL(z.cb, OnClosed).Times(0);
1515   EXPECT_CALL(z.cb, OnAborted).Times(0);
1516 
1517   a.socket.Connect();
1518   std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
1519   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
1520                               SctpPacket::Parse(init_data));
1521   ASSERT_HAS_VALUE_AND_ASSIGN(
1522       InitChunk init_chunk,
1523       InitChunk::Parse(init_packet.descriptors()[0].data));
1524   z.socket.ReceivePacket(init_data);
1525   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
1526   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
1527   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
1528 
1529   // Fill up Z2 to the high watermark limit.
1530   constexpr size_t kWatermarkLimit =
1531       kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit;
1532   constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit;
1533 
1534   TSN tsn = init_chunk.initial_tsn();
1535   AnyDataChunk::Options opts;
1536   opts.is_beginning = Data::IsBeginning(true);
1537   z.socket.ReceivePacket(
1538       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1539           .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
1540                          std::vector<uint8_t>(kWatermarkLimit + 1), opts))
1541           .Build());
1542 
1543   // First DATA will always trigger a SACK. It's not interesting.
1544   EXPECT_THAT(z.cb.ConsumeSentPacket(),
1545               AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks()));
1546 
1547   // This DATA should be accepted - it's advancing cum ack tsn.
1548   z.socket.ReceivePacket(
1549       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1550           .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53),
1551                          std::vector<uint8_t>(1),
1552                          /*options=*/{}))
1553           .Build());
1554 
1555   // The receiver might have moved into delayed ack mode.
1556   AdvanceTime(a, z, z.options.rto_initial);
1557 
1558   EXPECT_THAT(
1559       z.cb.ConsumeSentPacket(),
1560       AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
1561 
1562   // This DATA will not be accepted - it's not advancing cum ack tsn.
1563   z.socket.ReceivePacket(
1564       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1565           .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
1566                          std::vector<uint8_t>(1),
1567                          /*options=*/{}))
1568           .Build());
1569 
1570   // Sack will be sent in IMMEDIATE mode when this is happening.
1571   EXPECT_THAT(
1572       z.cb.ConsumeSentPacket(),
1573       AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
1574 
1575   // This DATA will not be accepted either.
1576   z.socket.ReceivePacket(
1577       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1578           .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53),
1579                          std::vector<uint8_t>(1),
1580                          /*options=*/{}))
1581           .Build());
1582 
1583   // Sack will be sent in IMMEDIATE mode when this is happening.
1584   EXPECT_THAT(
1585       z.cb.ConsumeSentPacket(),
1586       AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
1587 
1588   // This DATA should be accepted, and it fills the reassembly queue.
1589   z.socket.ReceivePacket(
1590       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1591           .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53),
1592                          std::vector<uint8_t>(kRemainingSize),
1593                          /*options=*/{}))
1594           .Build());
1595 
1596   // The receiver might have moved into delayed ack mode.
1597   AdvanceTime(a, z, z.options.rto_initial);
1598 
1599   EXPECT_THAT(
1600       z.cb.ConsumeSentPacket(),
1601       AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks()));
1602 
1603   EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _));
1604   EXPECT_CALL(z.cb, OnClosed).Times(0);
1605 
1606   // This DATA will make the connection close. It's too full now.
1607   z.socket.ReceivePacket(
1608       SctpPacket::Builder(z.socket.verification_tag(), z.options)
1609           .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
1610                          std::vector<uint8_t>(kSmallMessageSize),
1611                          /*options=*/{}))
1612           .Build());
1613 }
1614 
TEST(DcSctpSocketTest,SetMaxMessageSize)1615 TEST(DcSctpSocketTest, SetMaxMessageSize) {
1616   SocketUnderTest a("A");
1617 
1618   a.socket.SetMaxMessageSize(42u);
1619   EXPECT_EQ(a.socket.options().max_message_size, 42u);
1620 }
1621 
TEST_P(DcSctpSocketParametrizedTest,SendsMessagesWithLowLifetime)1622 TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
1623   SocketUnderTest a("A");
1624   auto z = std::make_unique<SocketUnderTest>("Z");
1625 
1626   ConnectSockets(a, *z);
1627   z = MaybeHandoverSocket(std::move(z));
1628 
1629   // Mock that the time always goes forward.
1630   TimeMs now(0);
1631   EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
1632     now += DurationMs(3);
1633     return now;
1634   });
1635   EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
1636     now += DurationMs(3);
1637     return now;
1638   });
1639 
1640   // Queue a few small messages with low lifetime, both ordered and unordered,
1641   // and validate that all are delivered.
1642   static constexpr int kIterations = 100;
1643   for (int i = 0; i < kIterations; ++i) {
1644     SendOptions send_options;
1645     send_options.unordered = IsUnordered((i % 2) == 0);
1646     send_options.lifetime = DurationMs(i % 3);  // 0, 1, 2 ms
1647 
1648     a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
1649   }
1650 
1651   ExchangeMessages(a, *z);
1652 
1653   for (int i = 0; i < kIterations; ++i) {
1654     EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value());
1655   }
1656 
1657   EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
1658 
1659   // Validate that the sockets really make the time move forward.
1660   EXPECT_GE(*now, kIterations * 2);
1661 
1662   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1663 }
1664 
TEST_P(DcSctpSocketParametrizedTest,DiscardsMessagesWithLowLifetimeIfMustBuffer)1665 TEST_P(DcSctpSocketParametrizedTest,
1666        DiscardsMessagesWithLowLifetimeIfMustBuffer) {
1667   SocketUnderTest a("A");
1668   auto z = std::make_unique<SocketUnderTest>("Z");
1669 
1670   ConnectSockets(a, *z);
1671   z = MaybeHandoverSocket(std::move(z));
1672 
1673   SendOptions lifetime_0;
1674   lifetime_0.unordered = IsUnordered(true);
1675   lifetime_0.lifetime = DurationMs(0);
1676 
1677   SendOptions lifetime_1;
1678   lifetime_1.unordered = IsUnordered(true);
1679   lifetime_1.lifetime = DurationMs(1);
1680 
1681   // Mock that the time always goes forward.
1682   TimeMs now(0);
1683   EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
1684     now += DurationMs(3);
1685     return now;
1686   });
1687   EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
1688     now += DurationMs(3);
1689     return now;
1690   });
1691 
1692   // Fill up the send buffer with a large message.
1693   std::vector<uint8_t> payload(kLargeMessageSize);
1694   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
1695 
1696   // And queue a few small messages with lifetime=0 or 1 ms - can't be sent.
1697   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0);
1698   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1);
1699   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0);
1700 
1701   // Handle all that was sent until congestion window got full.
1702   for (;;) {
1703     std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
1704     if (packet_from_a.empty()) {
1705       break;
1706     }
1707     z->socket.ReceivePacket(std::move(packet_from_a));
1708   }
1709 
1710   // Shouldn't be enough to send that large message.
1711   EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
1712 
1713   // Exchange the rest of the messages, with the time ever increasing.
1714   ExchangeMessages(a, *z);
1715 
1716   // The large message should be delivered. It was sent reliably.
1717   ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage());
1718   EXPECT_EQ(m1.stream_id(), StreamID(1));
1719   EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize));
1720 
1721   // But none of the smaller messages.
1722   EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
1723 
1724   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1725 }
1726 
TEST_P(DcSctpSocketParametrizedTest,HasReasonableBufferedAmountValues)1727 TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
1728   SocketUnderTest a("A");
1729   auto z = std::make_unique<SocketUnderTest>("Z");
1730 
1731   ConnectSockets(a, *z);
1732   z = MaybeHandoverSocket(std::move(z));
1733 
1734   EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
1735 
1736   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1737                               std::vector<uint8_t>(kSmallMessageSize)),
1738                 kSendOptions);
1739   // Sending a small message will directly send it as a single packet, so
1740   // nothing is left in the queue.
1741   EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
1742 
1743   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1744                               std::vector<uint8_t>(kLargeMessageSize)),
1745                 kSendOptions);
1746 
1747   // Sending a message will directly start sending a few packets, so the
1748   // buffered amount is not the full message size.
1749   EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u);
1750   EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize);
1751 
1752   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1753 }
1754 
TEST(DcSctpSocketTest,HasDefaultOnBufferedAmountLowValueZero)1755 TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
1756   SocketUnderTest a("A");
1757   EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u);
1758 }
1759 
TEST_P(DcSctpSocketParametrizedTest,TriggersOnBufferedAmountLowWithDefaultValueZero)1760 TEST_P(DcSctpSocketParametrizedTest,
1761        TriggersOnBufferedAmountLowWithDefaultValueZero) {
1762   SocketUnderTest a("A");
1763   auto z = std::make_unique<SocketUnderTest>("Z");
1764 
1765   EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
1766   ConnectSockets(a, *z);
1767   z = MaybeHandoverSocket(std::move(z));
1768 
1769   EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1)));
1770   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1771                               std::vector<uint8_t>(kSmallMessageSize)),
1772                 kSendOptions);
1773   ExchangeMessages(a, *z);
1774 
1775   EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return());
1776   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1777 }
1778 
TEST_P(DcSctpSocketParametrizedTest,DoesntTriggerOnBufferedAmountLowIfBelowThreshold)1779 TEST_P(DcSctpSocketParametrizedTest,
1780        DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
1781   static constexpr size_t kMessageSize = 1000;
1782   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
1783 
1784   SocketUnderTest a("A");
1785   auto z = std::make_unique<SocketUnderTest>("Z");
1786 
1787   a.socket.SetBufferedAmountLowThreshold(StreamID(1),
1788                                          kBufferedAmountLowThreshold);
1789   EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
1790   ConnectSockets(a, *z);
1791   z = MaybeHandoverSocket(std::move(z));
1792 
1793   EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0);
1794   a.socket.Send(
1795       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
1796       kSendOptions);
1797   ExchangeMessages(a, *z);
1798 
1799   a.socket.Send(
1800       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
1801       kSendOptions);
1802   ExchangeMessages(a, *z);
1803 
1804   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1805 }
1806 
TEST_P(DcSctpSocketParametrizedTest,TriggersOnBufferedAmountMultipleTimes)1807 TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
1808   static constexpr size_t kMessageSize = 1000;
1809   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
1810 
1811   SocketUnderTest a("A");
1812   auto z = std::make_unique<SocketUnderTest>("Z");
1813 
1814   a.socket.SetBufferedAmountLowThreshold(StreamID(1),
1815                                          kBufferedAmountLowThreshold);
1816   EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
1817   ConnectSockets(a, *z);
1818   z = MaybeHandoverSocket(std::move(z));
1819 
1820   EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3);
1821   EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2);
1822   a.socket.Send(
1823       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
1824       kSendOptions);
1825   ExchangeMessages(a, *z);
1826 
1827   a.socket.Send(
1828       DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
1829       kSendOptions);
1830   ExchangeMessages(a, *z);
1831 
1832   a.socket.Send(
1833       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
1834       kSendOptions);
1835   ExchangeMessages(a, *z);
1836 
1837   a.socket.Send(
1838       DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
1839       kSendOptions);
1840   ExchangeMessages(a, *z);
1841 
1842   a.socket.Send(
1843       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
1844       kSendOptions);
1845   ExchangeMessages(a, *z);
1846 
1847   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1848 }
1849 
TEST_P(DcSctpSocketParametrizedTest,TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold)1850 TEST_P(DcSctpSocketParametrizedTest,
1851        TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
1852   static constexpr size_t kMessageSize = 1000;
1853   static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
1854 
1855   SocketUnderTest a("A");
1856   auto z = std::make_unique<SocketUnderTest>("Z");
1857 
1858   a.socket.SetBufferedAmountLowThreshold(StreamID(1),
1859                                          kBufferedAmountLowThreshold);
1860   EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
1861   ConnectSockets(a, *z);
1862   z = MaybeHandoverSocket(std::move(z));
1863 
1864   EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
1865 
1866   // Add a few messages to fill up the congestion window. When that is full,
1867   // messages will start to be fully buffered.
1868   while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
1869     a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1870                                 std::vector<uint8_t>(kMessageSize)),
1871                   kSendOptions);
1872   }
1873   size_t initial_buffered = a.socket.buffered_amount(StreamID(1));
1874   ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
1875 
1876   // Start ACKing packets, which will empty the send queue, and trigger the
1877   // callback.
1878   EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1);
1879   ExchangeMessages(a, *z);
1880 
1881   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1882 }
1883 
TEST_P(DcSctpSocketParametrizedTest,DoesntTriggerOnTotalBufferAmountLowWhenBelow)1884 TEST_P(DcSctpSocketParametrizedTest,
1885        DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
1886   SocketUnderTest a("A");
1887   auto z = std::make_unique<SocketUnderTest>("Z");
1888 
1889   ConnectSockets(a, *z);
1890   z = MaybeHandoverSocket(std::move(z));
1891 
1892   EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
1893 
1894   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1895                               std::vector<uint8_t>(kLargeMessageSize)),
1896                 kSendOptions);
1897 
1898   ExchangeMessages(a, *z);
1899 
1900   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1901 }
1902 
TEST_P(DcSctpSocketParametrizedTest,TriggersOnTotalBufferAmountLowWhenCrossingThreshold)1903 TEST_P(DcSctpSocketParametrizedTest,
1904        TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
1905   SocketUnderTest a("A");
1906   auto z = std::make_unique<SocketUnderTest>("Z");
1907 
1908   ConnectSockets(a, *z);
1909   z = MaybeHandoverSocket(std::move(z));
1910 
1911   EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
1912 
1913   // Fill up the send queue completely.
1914   for (;;) {
1915     if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1916                                     std::vector<uint8_t>(kLargeMessageSize)),
1917                       kSendOptions) == SendStatus::kErrorResourceExhaustion) {
1918       break;
1919     }
1920   }
1921 
1922   EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1);
1923   ExchangeMessages(a, *z);
1924 
1925   MaybeHandoverSocketAndSendMessage(a, std::move(z));
1926 }
1927 
TEST(DcSctpSocketTest,InitialMetricsAreUnset)1928 TEST(DcSctpSocketTest, InitialMetricsAreUnset) {
1929   SocketUnderTest a("A");
1930 
1931   EXPECT_FALSE(a.socket.GetMetrics().has_value());
1932 }
1933 
TEST(DcSctpSocketTest,MessageInterleavingMetricsAreSet)1934 TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) {
1935   std::vector<std::pair<bool, bool>> combinations = {
1936       {false, false}, {false, true}, {true, false}, {true, true}};
1937   for (const auto& [a_enable, z_enable] : combinations) {
1938     DcSctpOptions a_options = {.enable_message_interleaving = a_enable};
1939     DcSctpOptions z_options = {.enable_message_interleaving = z_enable};
1940 
1941     SocketUnderTest a("A", a_options);
1942     SocketUnderTest z("Z", z_options);
1943     ConnectSockets(a, z);
1944 
1945     EXPECT_EQ(a.socket.GetMetrics()->uses_message_interleaving,
1946               a_enable && z_enable);
1947   }
1948 }
1949 
TEST(DcSctpSocketTest,RxAndTxPacketMetricsIncrease)1950 TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
1951   SocketUnderTest a("A");
1952   SocketUnderTest z("Z");
1953 
1954   ConnectSockets(a, z);
1955 
1956   const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size *
1957                                 ReassemblyQueue::kHighWatermarkLimit;
1958 
1959   EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u);
1960   EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u);
1961   EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u);
1962   EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes,
1963             a.options.cwnd_mtus_initial * a.options.mtu);
1964   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
1965 
1966   EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u);
1967   EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u);
1968 
1969   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
1970   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
1971 
1972   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // DATA
1973   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());  // SACK
1974   EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
1975   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
1976 
1977   EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
1978 
1979   EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u);
1980   EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u);
1981   EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u);
1982 
1983   EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u);
1984   EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u);
1985 
1986   // Send one more (large - fragmented), and receive the delayed SACK.
1987   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
1988                               std::vector<uint8_t>(a.options.mtu * 2 + 1)),
1989                 kSendOptions);
1990   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u);
1991 
1992   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // DATA
1993   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // DATA
1994 
1995   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());  // SACK
1996   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
1997   EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u);
1998   EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
1999 
2000   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());  // DATA
2001 
2002   EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
2003 
2004   EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u);
2005   EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u);
2006   EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u);
2007 
2008   EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u);
2009   EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u);
2010 
2011   // Delayed sack
2012   AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
2013 
2014   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());  // SACK
2015   EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
2016   EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u);
2017   EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
2018 }
2019 
TEST_P(DcSctpSocketParametrizedTest,UnackDataAlsoIncludesSendQueue)2020 TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
2021   SocketUnderTest a("A");
2022   auto z = std::make_unique<SocketUnderTest>("Z");
2023 
2024   ConnectSockets(a, *z);
2025   z = MaybeHandoverSocket(std::move(z));
2026 
2027   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
2028                               std::vector<uint8_t>(kLargeMessageSize)),
2029                 kSendOptions);
2030   size_t payload_bytes =
2031       a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
2032 
2033   size_t expected_sent_packets = a.options.cwnd_mtus_initial;
2034 
2035   size_t expected_queued_bytes =
2036       kLargeMessageSize - expected_sent_packets * payload_bytes;
2037 
2038   size_t expected_queued_packets = expected_queued_bytes / payload_bytes;
2039 
2040   // Due to alignment, padding etc, it's hard to calculate the exact number, but
2041   // it should be in this range.
2042   EXPECT_GE(a.socket.GetMetrics()->unack_data_count,
2043             expected_sent_packets + expected_queued_packets);
2044 
2045   EXPECT_LE(a.socket.GetMetrics()->unack_data_count,
2046             expected_sent_packets + expected_queued_packets + 2);
2047 
2048   MaybeHandoverSocketAndSendMessage(a, std::move(z));
2049 }
2050 
TEST_P(DcSctpSocketParametrizedTest,DoesntSendMoreThanMaxBurstPackets)2051 TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
2052   SocketUnderTest a("A");
2053   auto z = std::make_unique<SocketUnderTest>("Z");
2054 
2055   ConnectSockets(a, *z);
2056   z = MaybeHandoverSocket(std::move(z));
2057 
2058   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
2059                               std::vector<uint8_t>(kLargeMessageSize)),
2060                 kSendOptions);
2061 
2062   for (int i = 0; i < kMaxBurstPackets; ++i) {
2063     std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
2064     EXPECT_THAT(packet, Not(IsEmpty()));
2065     z->socket.ReceivePacket(std::move(packet));  // DATA
2066   }
2067 
2068   EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
2069 
2070   ExchangeMessages(a, *z);
2071   MaybeHandoverSocketAndSendMessage(a, std::move(z));
2072 }
2073 
TEST_P(DcSctpSocketParametrizedTest,SendsOnlyLargePackets)2074 TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
2075   SocketUnderTest a("A");
2076   auto z = std::make_unique<SocketUnderTest>("Z");
2077 
2078   ConnectSockets(a, *z);
2079   z = MaybeHandoverSocket(std::move(z));
2080 
2081   // A really large message, to ensure that the congestion window is often full.
2082   constexpr size_t kMessageSize = 100000;
2083   a.socket.Send(
2084       DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
2085       kSendOptions);
2086 
2087   bool delivered_packet = false;
2088   std::vector<size_t> data_packet_sizes;
2089   do {
2090     delivered_packet = false;
2091     std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
2092     if (!packet_from_a.empty()) {
2093       data_packet_sizes.push_back(packet_from_a.size());
2094       delivered_packet = true;
2095       z->socket.ReceivePacket(std::move(packet_from_a));
2096     }
2097     std::vector<uint8_t> packet_from_z = z->cb.ConsumeSentPacket();
2098     if (!packet_from_z.empty()) {
2099       delivered_packet = true;
2100       a.socket.ReceivePacket(std::move(packet_from_z));
2101     }
2102   } while (delivered_packet);
2103 
2104   size_t packet_payload_bytes =
2105       a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
2106   // +1 accounts for padding, and rounding up.
2107   size_t expected_packets =
2108       (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1;
2109   EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets));
2110 
2111   // Remove the last size - it will be the remainder. But all other sizes should
2112   // be large.
2113   data_packet_sizes.pop_back();
2114 
2115   for (size_t size : data_packet_sizes) {
2116     // The 4 is for padding/alignment.
2117     EXPECT_GE(size, a.options.mtu - 4);
2118   }
2119 
2120   MaybeHandoverSocketAndSendMessage(a, std::move(z));
2121 }
2122 
TEST(DcSctpSocketTest,SendMessagesAfterHandover)2123 TEST(DcSctpSocketTest, SendMessagesAfterHandover) {
2124   SocketUnderTest a("A");
2125   auto z = std::make_unique<SocketUnderTest>("Z");
2126 
2127   ConnectSockets(a, *z);
2128 
2129   // Send message before handover to move socket to a not initial state
2130   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
2131   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
2132   z->cb.ConsumeReceivedMessage();
2133 
2134   z = HandoverSocket(std::move(z));
2135 
2136   absl::optional<DcSctpMessage> msg;
2137 
2138   RTC_LOG(LS_INFO) << "Sending A #1";
2139 
2140   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
2141   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
2142 
2143   msg = z->cb.ConsumeReceivedMessage();
2144   ASSERT_TRUE(msg.has_value());
2145   EXPECT_EQ(msg->stream_id(), StreamID(1));
2146   EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
2147 
2148   RTC_LOG(LS_INFO) << "Sending A #2";
2149 
2150   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
2151   z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
2152 
2153   msg = z->cb.ConsumeReceivedMessage();
2154   ASSERT_TRUE(msg.has_value());
2155   EXPECT_EQ(msg->stream_id(), StreamID(2));
2156   EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
2157 
2158   RTC_LOG(LS_INFO) << "Sending Z #1";
2159 
2160   z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
2161   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());  // ack
2162   a.socket.ReceivePacket(z->cb.ConsumeSentPacket());  // data
2163 
2164   msg = a.cb.ConsumeReceivedMessage();
2165   ASSERT_TRUE(msg.has_value());
2166   EXPECT_EQ(msg->stream_id(), StreamID(1));
2167   EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
2168 }
2169 
TEST(DcSctpSocketTest,CanDetectDcsctpImplementation)2170 TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) {
2171   SocketUnderTest a("A");
2172   SocketUnderTest z("Z");
2173 
2174   ConnectSockets(a, z);
2175 
2176   EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
2177 
2178   // As A initiated the connection establishment, Z will not receive enough
2179   // information to know about A's implementation
2180   EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown);
2181 }
2182 
TEST(DcSctpSocketTest,BothCanDetectDcsctpImplementation)2183 TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) {
2184   SocketUnderTest a("A");
2185   SocketUnderTest z("Z");
2186 
2187   EXPECT_CALL(a.cb, OnConnected).Times(1);
2188   EXPECT_CALL(z.cb, OnConnected).Times(1);
2189   a.socket.Connect();
2190   z.socket.Connect();
2191 
2192   ExchangeMessages(a, z);
2193 
2194   EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
2195   EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp);
2196 }
2197 
TEST_P(DcSctpSocketParametrizedTest,CanLoseFirstOrderedMessage)2198 TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) {
2199   SocketUnderTest a("A");
2200   auto z = std::make_unique<SocketUnderTest>("Z");
2201 
2202   ConnectSockets(a, *z);
2203   z = MaybeHandoverSocket(std::move(z));
2204 
2205   SendOptions send_options;
2206   send_options.unordered = IsUnordered(false);
2207   send_options.max_retransmissions = 0;
2208   std::vector<uint8_t> payload(a.options.mtu - 100);
2209 
2210   // Send a first message (SID=1, SSN=0)
2211   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
2212 
2213   // First DATA is lost, and retransmission timer will delete it.
2214   a.cb.ConsumeSentPacket();
2215   AdvanceTime(a, *z, a.options.rto_initial);
2216   ExchangeMessages(a, *z);
2217 
2218   // Send a second message (SID=0, SSN=1).
2219   a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
2220   ExchangeMessages(a, *z);
2221 
2222   // The Z socket should receive the second message, but not the first.
2223   absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
2224   ASSERT_TRUE(msg.has_value());
2225   EXPECT_EQ(msg->ppid(), PPID(52));
2226 
2227   EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
2228 
2229   MaybeHandoverSocketAndSendMessage(a, std::move(z));
2230 }
2231 
TEST(DcSctpSocketTest,ReceiveBothUnorderedAndOrderedWithSameTSN)2232 TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) {
2233   /* This issue was found by fuzzing. */
2234   SocketUnderTest a("A");
2235   SocketUnderTest z("Z");
2236 
2237   a.socket.Connect();
2238   std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
2239   ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
2240                               SctpPacket::Parse(init_data));
2241   ASSERT_HAS_VALUE_AND_ASSIGN(
2242       InitChunk init_chunk,
2243       InitChunk::Parse(init_packet.descriptors()[0].data));
2244   z.socket.ReceivePacket(init_data);
2245   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
2246   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
2247   a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
2248 
2249   // Receive a short unordered message with tsn=INITIAL_TSN+1
2250   TSN tsn = init_chunk.initial_tsn();
2251   AnyDataChunk::Options opts;
2252   opts.is_beginning = Data::IsBeginning(true);
2253   opts.is_end = Data::IsEnd(true);
2254   opts.is_unordered = IsUnordered(true);
2255   z.socket.ReceivePacket(
2256       SctpPacket::Builder(z.socket.verification_tag(), z.options)
2257           .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
2258                          std::vector<uint8_t>(10), opts))
2259           .Build());
2260 
2261   // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1].
2262   // This isn't allowed as it reuses TSN=53 with different properties, but it
2263   // shouldn't cause any issues.
2264   opts.is_unordered = IsUnordered(false);
2265   opts.is_end = Data::IsEnd(false);
2266   z.socket.ReceivePacket(
2267       SctpPacket::Builder(z.socket.verification_tag(), z.options)
2268           .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
2269                          std::vector<uint8_t>(10), opts))
2270           .Build());
2271 
2272   opts.is_beginning = Data::IsBeginning(false);
2273   opts.is_end = Data::IsEnd(true);
2274   z.socket.ReceivePacket(
2275       SctpPacket::Builder(z.socket.verification_tag(), z.options)
2276           .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
2277                          std::vector<uint8_t>(10), opts))
2278           .Build());
2279 }
2280 
TEST(DcSctpSocketTest,CloseTwoStreamsAtTheSameTime)2281 TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) {
2282   // Reported as https://crbug.com/1312009.
2283   SocketUnderTest a("A");
2284   SocketUnderTest z("Z");
2285 
2286   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
2287   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1);
2288   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
2289   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1);
2290 
2291   ConnectSockets(a, z);
2292 
2293   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
2294   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
2295 
2296   ExchangeMessages(a, z);
2297 
2298   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2299   a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
2300 
2301   ExchangeMessages(a, z);
2302 }
2303 
TEST(DcSctpSocketTest,CloseThreeStreamsAtTheSameTime)2304 TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) {
2305   // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two
2306   // remaining streams are reset at the same time in the second request.
2307   SocketUnderTest a("A");
2308   SocketUnderTest z("Z");
2309 
2310   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
2311   EXPECT_CALL(z.cb, OnIncomingStreamsReset(
2312                         UnorderedElementsAre(StreamID(2), StreamID(3))))
2313       .Times(1);
2314   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
2315   EXPECT_CALL(a.cb, OnStreamsResetPerformed(
2316                         UnorderedElementsAre(StreamID(2), StreamID(3))))
2317       .Times(1);
2318 
2319   ConnectSockets(a, z);
2320 
2321   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
2322   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
2323   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions);
2324 
2325   ExchangeMessages(a, z);
2326 
2327   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2328   a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
2329   a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
2330 
2331   ExchangeMessages(a, z);
2332 }
2333 
TEST(DcSctpSocketTest,CloseStreamsWithPendingRequest)2334 TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) {
2335   // Checks that stream reset requests are properly paused when they can't be
2336   // immediately reset - i.e. when there is already an ongoing stream reset
2337   // request (and there can only be a single one in-flight).
2338   SocketUnderTest a("A");
2339   SocketUnderTest z("Z");
2340 
2341   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
2342   EXPECT_CALL(z.cb, OnIncomingStreamsReset(
2343                         UnorderedElementsAre(StreamID(2), StreamID(3))))
2344       .Times(1);
2345   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
2346   EXPECT_CALL(a.cb, OnStreamsResetPerformed(
2347                         UnorderedElementsAre(StreamID(2), StreamID(3))))
2348       .Times(1);
2349 
2350   ConnectSockets(a, z);
2351 
2352   SendOptions send_options = {.unordered = IsUnordered(false)};
2353 
2354   // Send a few ordered messages
2355   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
2356   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
2357   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
2358 
2359   ExchangeMessages(a, z);
2360 
2361   // Receive these messages
2362   absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
2363   ASSERT_TRUE(msg1.has_value());
2364   EXPECT_EQ(msg1->stream_id(), StreamID(1));
2365   absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
2366   ASSERT_TRUE(msg2.has_value());
2367   EXPECT_EQ(msg2->stream_id(), StreamID(2));
2368   absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
2369   ASSERT_TRUE(msg3.has_value());
2370   EXPECT_EQ(msg3->stream_id(), StreamID(3));
2371 
2372   // Reset the streams - not all at once.
2373   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2374 
2375   std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
2376   EXPECT_THAT(packet, HasReconfigWithStreams(ElementsAre(StreamID(1))));
2377   z.socket.ReceivePacket(std::move(packet));
2378 
2379   // Sending more reset requests while this one is ongoing.
2380 
2381   a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
2382   a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
2383 
2384   ExchangeMessages(a, z);
2385 
2386   // Send a few more ordered messages
2387   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
2388   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
2389   a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
2390 
2391   ExchangeMessages(a, z);
2392 
2393   // Receive these messages
2394   absl::optional<DcSctpMessage> msg4 = z.cb.ConsumeReceivedMessage();
2395   ASSERT_TRUE(msg4.has_value());
2396   EXPECT_EQ(msg4->stream_id(), StreamID(1));
2397   absl::optional<DcSctpMessage> msg5 = z.cb.ConsumeReceivedMessage();
2398   ASSERT_TRUE(msg5.has_value());
2399   EXPECT_EQ(msg5->stream_id(), StreamID(2));
2400   absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage();
2401   ASSERT_TRUE(msg6.has_value());
2402   EXPECT_EQ(msg6->stream_id(), StreamID(3));
2403 }
2404 
TEST(DcSctpSocketTest,StreamsHaveInitialPriority)2405 TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
2406   DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
2407   SocketUnderTest a("A", options);
2408 
2409   EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)),
2410             options.default_stream_priority);
2411 
2412   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
2413 
2414   EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)),
2415             options.default_stream_priority);
2416 }
2417 
TEST(DcSctpSocketTest,CanChangeStreamPriority)2418 TEST(DcSctpSocketTest, CanChangeStreamPriority) {
2419   DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
2420   SocketUnderTest a("A", options);
2421 
2422   a.socket.SetStreamPriority(StreamID(1), StreamPriority(43));
2423   EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
2424 
2425   a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
2426 
2427   a.socket.SetStreamPriority(StreamID(2), StreamPriority(43));
2428   EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
2429 }
2430 
TEST_P(DcSctpSocketParametrizedTest,WillHandoverPriority)2431 TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
2432   DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
2433   auto a = std::make_unique<SocketUnderTest>("A", options);
2434   SocketUnderTest z("Z");
2435 
2436   ConnectSockets(*a, z);
2437 
2438   a->socket.SetStreamPriority(StreamID(1), StreamPriority(43));
2439   a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
2440   a->socket.SetStreamPriority(StreamID(2), StreamPriority(43));
2441 
2442   ExchangeMessages(*a, z);
2443 
2444   a = MaybeHandoverSocket(std::move(a));
2445 
2446   EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
2447   EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
2448 }
2449 
TEST(DcSctpSocketTest,ReconnectSocketWithPendingStreamReset)2450 TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) {
2451   // This is an issue found by fuzzing, and doesn't really make sense in WebRTC
2452   // data channels as a SCTP connection is never ever closed and then
2453   // reconnected. SCTP connections are closed when the peer connection is
2454   // deleted, and then it doesn't do more with SCTP.
2455   SocketUnderTest a("A");
2456   SocketUnderTest z("Z");
2457 
2458   ConnectSockets(a, z);
2459 
2460   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2461 
2462   EXPECT_CALL(z.cb, OnAborted).Times(1);
2463   a.socket.Close();
2464 
2465   EXPECT_EQ(a.socket.state(), SocketState::kClosed);
2466 
2467   EXPECT_CALL(a.cb, OnConnected).Times(1);
2468   EXPECT_CALL(z.cb, OnConnected).Times(1);
2469   a.socket.Connect();
2470   ExchangeMessages(a, z);
2471   a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
2472 }
2473 
TEST(DcSctpSocketTest,SmallSentMessagesWithPrioWillArriveInSpecificOrder)2474 TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) {
2475   DcSctpOptions options = {.enable_message_interleaving = true};
2476   SocketUnderTest a("A", options);
2477   SocketUnderTest z("A", options);
2478 
2479   a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
2480   a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
2481   a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
2482 
2483   // Enqueue messages before connecting the socket, to ensure they aren't send
2484   // as soon as Send() is called.
2485   a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
2486                               std::vector<uint8_t>(kSmallMessageSize)),
2487                 kSendOptions);
2488   a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
2489                               std::vector<uint8_t>(kSmallMessageSize)),
2490                 kSendOptions);
2491   a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
2492                               std::vector<uint8_t>(kSmallMessageSize)),
2493                 kSendOptions);
2494   a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
2495                               std::vector<uint8_t>(kSmallMessageSize)),
2496                 kSendOptions);
2497   a.socket.Send(DcSctpMessage(StreamID(1), PPID(103),
2498                               std::vector<uint8_t>(kSmallMessageSize)),
2499                 kSendOptions);
2500 
2501   ConnectSockets(a, z);
2502   ExchangeMessages(a, z);
2503 
2504   std::vector<uint32_t> received_ppids;
2505   for (;;) {
2506     absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
2507     if (!msg.has_value()) {
2508       break;
2509     }
2510     received_ppids.push_back(*msg->ppid());
2511   }
2512 
2513   EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301));
2514 }
2515 
TEST(DcSctpSocketTest,LargeSentMessagesWithPrioWillArriveInSpecificOrder)2516 TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) {
2517   DcSctpOptions options = {.enable_message_interleaving = true};
2518   SocketUnderTest a("A", options);
2519   SocketUnderTest z("A", options);
2520 
2521   a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
2522   a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
2523   a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
2524 
2525   // Enqueue messages before connecting the socket, to ensure they aren't send
2526   // as soon as Send() is called.
2527   a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
2528                               std::vector<uint8_t>(kLargeMessageSize)),
2529                 kSendOptions);
2530   a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
2531                               std::vector<uint8_t>(kLargeMessageSize)),
2532                 kSendOptions);
2533   a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
2534                               std::vector<uint8_t>(kLargeMessageSize)),
2535                 kSendOptions);
2536   a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
2537                               std::vector<uint8_t>(kLargeMessageSize)),
2538                 kSendOptions);
2539 
2540   ConnectSockets(a, z);
2541   ExchangeMessages(a, z);
2542 
2543   EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301));
2544 }
2545 
TEST(DcSctpSocketTest,MessageWithHigherPrioWillInterruptLowerPrioMessage)2546 TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) {
2547   DcSctpOptions options = {.enable_message_interleaving = true};
2548   SocketUnderTest a("A", options);
2549   SocketUnderTest z("Z", options);
2550 
2551   ConnectSockets(a, z);
2552 
2553   a.socket.SetStreamPriority(StreamID(2), StreamPriority(128));
2554   a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
2555                               std::vector<uint8_t>(kLargeMessageSize)),
2556                 kSendOptions);
2557 
2558   // Due to a non-zero initial congestion window, the message will already start
2559   // to send, but will not succeed to be sent completely before filling the
2560   // congestion window or stopping due to reaching how many packets that can be
2561   // sent at once (max burst). The important thing is that the entire message
2562   // doesn't get sent in full.
2563 
2564   // Now enqueue two messages; one small and one large higher priority message.
2565   a.socket.SetStreamPriority(StreamID(1), StreamPriority(512));
2566   a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
2567                               std::vector<uint8_t>(kSmallMessageSize)),
2568                 kSendOptions);
2569   a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
2570                               std::vector<uint8_t>(kLargeMessageSize)),
2571                 kSendOptions);
2572 
2573   ExchangeMessages(a, z);
2574 
2575   EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201));
2576 }
2577 
TEST(DcSctpSocketTest,LifecycleEventsAreGeneratedForAckedMessages)2578 TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) {
2579   SocketUnderTest a("A");
2580   SocketUnderTest z("Z");
2581   ConnectSockets(a, z);
2582 
2583   a.socket.Send(DcSctpMessage(StreamID(2), PPID(101),
2584                               std::vector<uint8_t>(kLargeMessageSize)),
2585                 {.lifecycle_id = LifecycleId(41)});
2586 
2587   a.socket.Send(DcSctpMessage(StreamID(2), PPID(102),
2588                               std::vector<uint8_t>(kLargeMessageSize)),
2589                 kSendOptions);
2590 
2591   a.socket.Send(DcSctpMessage(StreamID(2), PPID(103),
2592                               std::vector<uint8_t>(kLargeMessageSize)),
2593                 {.lifecycle_id = LifecycleId(42)});
2594 
2595   EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41)));
2596   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41)));
2597   EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42)));
2598   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42)));
2599   ExchangeMessages(a, z);
2600   // In case of delayed ack.
2601   AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
2602   ExchangeMessages(a, z);
2603 
2604   EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103));
2605 }
2606 
TEST(DcSctpSocketTest,LifecycleEventsForFailMaxRetransmissions)2607 TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) {
2608   SocketUnderTest a("A");
2609   SocketUnderTest z("Z");
2610   ConnectSockets(a, z);
2611 
2612   std::vector<uint8_t> payload(a.options.mtu - 100);
2613   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
2614                 {
2615                     .max_retransmissions = 0,
2616                     .lifecycle_id = LifecycleId(1),
2617                 });
2618   a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload),
2619                 {
2620                     .max_retransmissions = 0,
2621                     .lifecycle_id = LifecycleId(2),
2622                 });
2623   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload),
2624                 {
2625                     .max_retransmissions = 0,
2626                     .lifecycle_id = LifecycleId(3),
2627                 });
2628 
2629   // First DATA
2630   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
2631   // Second DATA (lost)
2632   a.cb.ConsumeSentPacket();
2633 
2634   EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1)));
2635   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
2636   EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2),
2637                                               /*maybe_delivered=*/true));
2638   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2)));
2639   EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3)));
2640   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3)));
2641   ExchangeMessages(a, z);
2642 
2643   // Handle delayed SACK.
2644   AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
2645   ExchangeMessages(a, z);
2646 
2647   // The chunk is now NACKed. Let the RTO expire, to discard the message.
2648   AdvanceTime(a, z, a.options.rto_initial);
2649   ExchangeMessages(a, z);
2650 
2651   // Handle delayed SACK.
2652   AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
2653   ExchangeMessages(a, z);
2654 
2655   EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53));
2656 }
2657 
TEST(DcSctpSocketTest,LifecycleEventsForExpiredMessageWithRetransmitLimit)2658 TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) {
2659   SocketUnderTest a("A");
2660   SocketUnderTest z("Z");
2661   ConnectSockets(a, z);
2662 
2663   // Will not be able to send it in full within the congestion window, but will
2664   // need to wait for SACKs to be received for more fragments to be sent.
2665   std::vector<uint8_t> payload(kLargeMessageSize);
2666   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
2667                 {
2668                     .max_retransmissions = 0,
2669                     .lifecycle_id = LifecycleId(1),
2670                 });
2671 
2672   // First DATA
2673   z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
2674   // Second DATA (lost)
2675   a.cb.ConsumeSentPacket();
2676 
2677   EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
2678                                               /*maybe_delivered=*/false));
2679   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
2680   ExchangeMessages(a, z);
2681 
2682   EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
2683 }
2684 
TEST(DcSctpSocketTest,LifecycleEventsForExpiredMessageWithLifetimeLimit)2685 TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) {
2686   SocketUnderTest a("A");
2687   SocketUnderTest z("Z");
2688 
2689   // Send it before the socket is connected, to prevent it from being sent too
2690   // quickly. The idea is that it should be expired before even attempting to
2691   // send it in full.
2692   std::vector<uint8_t> payload(kSmallMessageSize);
2693   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
2694                 {
2695                     .lifetime = DurationMs(100),
2696                     .lifecycle_id = LifecycleId(1),
2697                 });
2698 
2699   AdvanceTime(a, z, DurationMs(200));
2700 
2701   EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
2702                                               /*maybe_delivered=*/false));
2703   EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
2704   ConnectSockets(a, z);
2705   ExchangeMessages(a, z);
2706 
2707   EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
2708 }
2709 
TEST_P(DcSctpSocketParametrizedTest,ExposesTheNumberOfNegotiatedStreams)2710 TEST_P(DcSctpSocketParametrizedTest, ExposesTheNumberOfNegotiatedStreams) {
2711   DcSctpOptions options_a = {
2712       .announced_maximum_incoming_streams = 12,
2713       .announced_maximum_outgoing_streams = 45,
2714   };
2715   SocketUnderTest a("A", options_a);
2716 
2717   DcSctpOptions options_z = {
2718       .announced_maximum_incoming_streams = 23,
2719       .announced_maximum_outgoing_streams = 34,
2720   };
2721   auto z = std::make_unique<SocketUnderTest>("Z", options_z);
2722 
2723   ConnectSockets(a, *z);
2724   z = MaybeHandoverSocket(std::move(z));
2725 
2726   ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_a, a.socket.GetMetrics());
2727   EXPECT_EQ(metrics_a.negotiated_maximum_incoming_streams, 12);
2728   EXPECT_EQ(metrics_a.negotiated_maximum_outgoing_streams, 23);
2729 
2730   ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_z, z->socket.GetMetrics());
2731   EXPECT_EQ(metrics_z.negotiated_maximum_incoming_streams, 23);
2732   EXPECT_EQ(metrics_z.negotiated_maximum_outgoing_streams, 12);
2733 }
2734 
TEST(DcSctpSocketTest,ResetStreamsDeferred)2735 TEST(DcSctpSocketTest, ResetStreamsDeferred) {
2736   // Guaranteed to be fragmented into two fragments.
2737   constexpr size_t kTwoFragmentsSize = DcSctpOptions::kMaxSafeMTUSize + 100;
2738 
2739   SocketUnderTest a("A");
2740   SocketUnderTest z("Z");
2741 
2742   ConnectSockets(a, z);
2743 
2744   a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
2745                               std::vector<uint8_t>(kTwoFragmentsSize)),
2746                 {});
2747   a.socket.Send(DcSctpMessage(StreamID(1), PPID(54),
2748                               std::vector<uint8_t>(kSmallMessageSize)),
2749                 {});
2750 
2751   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2752 
2753   auto data1 = a.cb.ConsumeSentPacket();
2754   auto data2 = a.cb.ConsumeSentPacket();
2755   auto data3 = a.cb.ConsumeSentPacket();
2756   auto reconfig = a.cb.ConsumeSentPacket();
2757 
2758   EXPECT_THAT(data1, HasDataChunkWithSsn(SSN(0)));
2759   EXPECT_THAT(data2, HasDataChunkWithSsn(SSN(0)));
2760   EXPECT_THAT(data3, HasDataChunkWithSsn(SSN(1)));
2761   EXPECT_THAT(reconfig, HasReconfigWithStreams(ElementsAre(StreamID(1))));
2762 
2763   // Receive them slightly out of order to make stream resetting deferred.
2764   z.socket.ReceivePacket(reconfig);
2765 
2766   z.socket.ReceivePacket(data1);
2767   z.socket.ReceivePacket(data2);
2768   z.socket.ReceivePacket(data3);
2769 
2770   absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
2771   ASSERT_TRUE(msg1.has_value());
2772   EXPECT_EQ(msg1->stream_id(), StreamID(1));
2773   EXPECT_EQ(msg1->ppid(), PPID(53));
2774   EXPECT_EQ(msg1->payload().size(), kTwoFragmentsSize);
2775 
2776   absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
2777   ASSERT_TRUE(msg2.has_value());
2778   EXPECT_EQ(msg2->stream_id(), StreamID(1));
2779   EXPECT_EQ(msg2->ppid(), PPID(54));
2780   EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);
2781 
2782   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
2783   ExchangeMessages(a, z);
2784 
2785   // Z sent "in progress", which will make A buffer packets until it's sure
2786   // that the reconfiguration has been applied. A will retry - wait for that.
2787   AdvanceTime(a, z, a.options.rto_initial);
2788 
2789   auto reconfig2 = a.cb.ConsumeSentPacket();
2790   EXPECT_THAT(reconfig2, HasReconfigWithStreams(ElementsAre(StreamID(1))));
2791   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
2792   z.socket.ReceivePacket(reconfig2);
2793 
2794   auto reconfig3 = z.cb.ConsumeSentPacket();
2795   EXPECT_THAT(reconfig3,
2796               HasReconfigWithResponse(
2797                   ReconfigurationResponseParameter::Result::kSuccessPerformed));
2798   a.socket.ReceivePacket(reconfig3);
2799 
2800   EXPECT_THAT(data1, HasDataChunkWithSsn(SSN(0)));
2801   EXPECT_THAT(data2, HasDataChunkWithSsn(SSN(0)));
2802   EXPECT_THAT(data3, HasDataChunkWithSsn(SSN(1)));
2803   EXPECT_THAT(reconfig, HasReconfigWithStreams(ElementsAre(StreamID(1))));
2804 
2805   // Send a new message after the stream has been reset.
2806   a.socket.Send(DcSctpMessage(StreamID(1), PPID(55),
2807                               std::vector<uint8_t>(kSmallMessageSize)),
2808                 {});
2809   ExchangeMessages(a, z);
2810 
2811   absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
2812   ASSERT_TRUE(msg3.has_value());
2813   EXPECT_EQ(msg3->stream_id(), StreamID(1));
2814   EXPECT_EQ(msg3->ppid(), PPID(55));
2815   EXPECT_EQ(msg3->payload().size(), kSmallMessageSize);
2816 }
2817 
TEST(DcSctpSocketTest,ResetStreamsWithPausedSenderResumesWhenPerformed)2818 TEST(DcSctpSocketTest, ResetStreamsWithPausedSenderResumesWhenPerformed) {
2819   SocketUnderTest a("A");
2820   SocketUnderTest z("Z");
2821 
2822   ConnectSockets(a, z);
2823 
2824   a.socket.Send(DcSctpMessage(StreamID(1), PPID(51),
2825                               std::vector<uint8_t>(kSmallMessageSize)),
2826                 {});
2827 
2828   a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
2829 
2830   // Will be queued, as the stream has an outstanding reset operation.
2831   a.socket.Send(DcSctpMessage(StreamID(1), PPID(52),
2832                               std::vector<uint8_t>(kSmallMessageSize)),
2833                 {});
2834 
2835   EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
2836   EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
2837   ExchangeMessages(a, z);
2838 
2839   absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
2840   ASSERT_TRUE(msg1.has_value());
2841   EXPECT_EQ(msg1->stream_id(), StreamID(1));
2842   EXPECT_EQ(msg1->ppid(), PPID(51));
2843   EXPECT_EQ(msg1->payload().size(), kSmallMessageSize);
2844 
2845   absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
2846   ASSERT_TRUE(msg2.has_value());
2847   EXPECT_EQ(msg2->stream_id(), StreamID(1));
2848   EXPECT_EQ(msg2->ppid(), PPID(52));
2849   EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);
2850 }
2851 
2852 }  // namespace
2853 }  // namespace dcsctp
2854