xref: /aosp_15_r20/external/webrtc/net/dcsctp/socket/heartbeat_handler_test.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "net/dcsctp/socket/heartbeat_handler.h"
11 
12 #include <memory>
13 #include <utility>
14 #include <vector>
15 
16 #include "api/task_queue/task_queue_base.h"
17 #include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
18 #include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
19 #include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
20 #include "net/dcsctp/public/types.h"
21 #include "net/dcsctp/socket/mock_context.h"
22 #include "net/dcsctp/testing/testing_macros.h"
23 #include "rtc_base/gunit.h"
24 #include "test/gmock.h"
25 
26 namespace dcsctp {
27 namespace {
28 using ::testing::ElementsAre;
29 using ::testing::IsEmpty;
30 using ::testing::NiceMock;
31 using ::testing::Return;
32 using ::testing::SizeIs;
33 
34 constexpr DurationMs kHeartbeatInterval = DurationMs(30'000);
35 
MakeOptions(DurationMs heartbeat_interval)36 DcSctpOptions MakeOptions(DurationMs heartbeat_interval) {
37   DcSctpOptions options;
38   options.heartbeat_interval_include_rtt = false;
39   options.heartbeat_interval = heartbeat_interval;
40   return options;
41 }
42 
43 class HeartbeatHandlerTestBase : public testing::Test {
44  protected:
HeartbeatHandlerTestBase(DurationMs heartbeat_interval)45   explicit HeartbeatHandlerTestBase(DurationMs heartbeat_interval)
46       : options_(MakeOptions(heartbeat_interval)),
47         context_(&callbacks_),
48         timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) {
49           return callbacks_.CreateTimeout(precision);
50         }),
51         handler_("log: ", options_, &context_, &timer_manager_) {}
52 
AdvanceTime(DurationMs duration)53   void AdvanceTime(DurationMs duration) {
54     callbacks_.AdvanceTime(duration);
55     for (;;) {
56       absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout();
57       if (!timeout_id.has_value()) {
58         break;
59       }
60       timer_manager_.HandleTimeout(*timeout_id);
61     }
62   }
63 
64   const DcSctpOptions options_;
65   NiceMock<MockDcSctpSocketCallbacks> callbacks_;
66   NiceMock<MockContext> context_;
67   TimerManager timer_manager_;
68   HeartbeatHandler handler_;
69 };
70 
71 class HeartbeatHandlerTest : public HeartbeatHandlerTestBase {
72  protected:
HeartbeatHandlerTest()73   HeartbeatHandlerTest() : HeartbeatHandlerTestBase(kHeartbeatInterval) {}
74 };
75 
76 class DisabledHeartbeatHandlerTest : public HeartbeatHandlerTestBase {
77  protected:
DisabledHeartbeatHandlerTest()78   DisabledHeartbeatHandlerTest() : HeartbeatHandlerTestBase(DurationMs(0)) {}
79 };
80 
// After one full heartbeat interval, the handler should have sent a packet
// containing a single HEARTBEAT chunk that carries heartbeat info.
TEST_F(HeartbeatHandlerTest, HasRunningHeartbeatIntervalTimer) {
  AdvanceTime(options_.heartbeat_interval);

  // The interval timer fired. The sent packet must hold exactly one chunk,
  // and that chunk must parse as a HEARTBEAT request.
  std::vector<uint8_t> sent = callbacks_.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket sent_packet, SctpPacket::Parse(sent));
  ASSERT_THAT(sent_packet.descriptors(), SizeIs(1));

  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatRequestChunk heartbeat,
      HeartbeatRequestChunk::Parse(sent_packet.descriptors()[0].data));

  EXPECT_TRUE(heartbeat.info().has_value());
}
95 
// An incoming HEARTBEAT must be answered with a HEARTBEAT-ACK that echoes
// the request's heartbeat info byte-for-byte.
TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) {
  uint8_t opaque_info[] = {1, 2, 3, 4, 5};
  HeartbeatRequestChunk incoming(
      Parameters::Builder().Add(HeartbeatInfoParameter(opaque_info)).Build());

  handler_.HandleHeartbeatRequest(std::move(incoming));

  // The reply should be a single-chunk packet holding a HEARTBEAT-ACK.
  std::vector<uint8_t> sent = callbacks_.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet, SctpPacket::Parse(sent));
  ASSERT_THAT(reply_packet.descriptors(), SizeIs(1));

  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatAckChunk ack,
      HeartbeatAckChunk::Parse(reply_packet.descriptors()[0].data));

  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatInfoParameter echoed,
      ack.parameters().get<HeartbeatInfoParameter>());

  EXPECT_THAT(echoed.info(), ElementsAre(1, 2, 3, 4, 5));
}
117 
// Answers the periodic HEARTBEAT after a known delay and checks that the
// handler reports exactly that delay as the measured RTT.
TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) {
  AdvanceTime(options_.heartbeat_interval);

  // Capture the HEARTBEAT produced by the interval timer.
  std::vector<uint8_t> sent = callbacks_.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket sent_packet, SctpPacket::Parse(sent));
  ASSERT_THAT(sent_packet.descriptors(), SizeIs(1));

  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatRequestChunk heartbeat,
      HeartbeatRequestChunk::Parse(sent_packet.descriptors()[0].data));

  // Build the peer's HEARTBEAT-ACK by echoing the request's parameters.
  HeartbeatAckChunk ack(std::move(heartbeat).extract_parameters());

  // Deliver the ACK after this much simulated time. The handler should
  // observe exactly this value as the round-trip time.
  constexpr DurationMs rtt(313);
  EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1);

  // Only the clock moves here; no timers are dispatched.
  callbacks_.AdvanceTime(rtt);
  handler_.HandleHeartbeatAck(std::move(ack));
}
140 
// A HEARTBEAT-ACK whose embedded timestamp lies in the future is invalid
// and must not produce any RTT observation.
TEST_F(HeartbeatHandlerTest, DoesntObserveInvalidHeartbeats) {
  AdvanceTime(options_.heartbeat_interval);

  // Capture the HEARTBEAT produced by the interval timer.
  std::vector<uint8_t> sent = callbacks_.ConsumeSentPacket();
  ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket sent_packet, SctpPacket::Parse(sent));
  ASSERT_THAT(sent_packet.descriptors(), SizeIs(1));

  ASSERT_HAS_VALUE_AND_ASSIGN(
      HeartbeatRequestChunk heartbeat,
      HeartbeatRequestChunk::Parse(sent_packet.descriptors()[0].data));

  // Build the peer's HEARTBEAT-ACK by echoing the request's parameters.
  HeartbeatAckChunk ack(std::move(heartbeat).extract_parameters());

  EXPECT_CALL(context_, ObserveRTT).Times(0);

  // Step the clock backwards. This makes the timestamp carried in the
  // HEARTBEAT-ACK lie in the future, which makes it invalid.
  callbacks_.AdvanceTime(DurationMs(-100));

  handler_.HandleHeartbeatAck(std::move(ack));
}
163 
// If a HEARTBEAT goes unacknowledged for one RTO, the handler must increment
// the transmit error counter exactly once.
TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) {
  // The handler is expected to query current_rto once while the interval
  // timer fires and the request is sent. Pin it to a known value.
  constexpr DurationMs rto(105);
  EXPECT_CALL(context_, current_rto).WillOnce(Return(rto));
  AdvanceTime(options_.heartbeat_interval);

  // Validate that a request was sent. `Not` is not among this file's
  // using-declarations, so qualify it explicitly instead of relying on
  // argument-dependent lookup through the IsEmpty() matcher's type.
  EXPECT_THAT(callbacks_.ConsumeSentPacket(), ::testing::Not(IsEmpty()));

  // With no HEARTBEAT-ACK arriving, one RTO later counts as one error.
  EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1);
  AdvanceTime(rto);
}
175 
// With a zero heartbeat interval, no HEARTBEAT may be sent. This must hold
// even after a stretch of time in which an enabled handler would have sent
// one.
TEST_F(DisabledHeartbeatHandlerTest, IsReallyDisabled) {
  // options_.heartbeat_interval is zero in this fixture. Advancing by it
  // would not move the clock at all, and the check below would pass
  // vacuously. Advance by the interval the enabled fixture uses instead.
  AdvanceTime(kHeartbeatInterval);

  // Validate that a request was NOT sent.
  EXPECT_THAT(callbacks_.ConsumeSentPacket(), IsEmpty());
}
182 
183 }  // namespace
184 }  // namespace dcsctp
185