1 /*
2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"
11
#include <string>
#include <utility>
#include <vector>

#include "net/dcsctp/common/math.h"
#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
#include "net/dcsctp/packet/chunk/data_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/socket/dcsctp_socket.h"
#include "net/dcsctp/socket/state_cookie.h"
#include "rtc_base/logging.h"
33
34 namespace dcsctp {
35 namespace dcsctp_fuzzers {
36 namespace {
37 static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue;
38 static constexpr size_t kMinInputLength = 5;
39 static constexpr size_t kMaxInputLength = 1024;
40
// The state the socket under test is driven into before any fuzz-selected
// actions run. FuzzSocket takes the value from the first input byte, so the
// numbering of the enumerators is significant.
enum class StartingState : int {
  kConnectNotCalled = 0,

  // The socket under test initiates the association.
  kConnectCalled,
  kReceivedInitAck,
  kReceivedCookieAck,

  // The socket under test initiates shutdown of an established association.
  kShutdownCalled,
  kReceivedShutdownAck,

  // The peer initiates the association.
  kReceivedInit,
  kReceivedCookieEcho,

  // The peer initiates shutdown of an established association.
  kReceivedShutdown,
  kReceivedShutdownComplete,

  // Count of the states above; not itself a valid starting state.
  kNumberOfStates,
};
59
60 // State about the current fuzzing iteration
61 class FuzzState {
62 public:
FuzzState(rtc::ArrayView<const uint8_t> data)63 explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {}
64
GetByte()65 uint8_t GetByte() {
66 uint8_t value = 0;
67 if (offset_ < data_.size()) {
68 value = data_[offset_];
69 ++offset_;
70 }
71 return value;
72 }
73
GetNextTSN()74 TSN GetNextTSN() { return TSN(tsn_++); }
GetNextMID()75 MID GetNextMID() { return MID(mid_++); }
76
empty() const77 bool empty() const { return offset_ >= data_.size(); }
78
79 private:
80 uint32_t tsn_ = kRandomValue;
81 uint32_t mid_ = 0;
82 rtc::ArrayView<const uint8_t> data_;
83 size_t offset_ = 0;
84 };
85
SetSocketState(DcSctpSocketInterface & socket,FuzzerCallbacks & socket_cb,StartingState state)86 void SetSocketState(DcSctpSocketInterface& socket,
87 FuzzerCallbacks& socket_cb,
88 StartingState state) {
89 // We'll use another temporary peer socket for the establishment.
90 FuzzerCallbacks peer_cb;
91 DcSctpSocket peer("peer", peer_cb, nullptr, {});
92
93 switch (state) {
94 case StartingState::kConnectNotCalled:
95 return;
96 case StartingState::kConnectCalled:
97 socket.Connect();
98 return;
99 case StartingState::kReceivedInitAck:
100 socket.Connect();
101 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
102 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
103 return;
104 case StartingState::kReceivedCookieAck:
105 socket.Connect();
106 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
107 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
108 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
109 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
110 return;
111 case StartingState::kShutdownCalled:
112 socket.Connect();
113 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
114 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
115 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
116 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
117 socket.Shutdown();
118 return;
119 case StartingState::kReceivedShutdownAck:
120 socket.Connect();
121 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
122 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
123 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
124 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
125 socket.Shutdown();
126 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN
127 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_ACK
128 return;
129 case StartingState::kReceivedInit:
130 peer.Connect();
131 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT
132 return;
133 case StartingState::kReceivedCookieEcho:
134 peer.Connect();
135 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT
136 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT_ACK
137 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ECHO
138 return;
139 case StartingState::kReceivedShutdown:
140 socket.Connect();
141 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
142 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
143 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
144 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
145 peer.Shutdown();
146 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN
147 return;
148 case StartingState::kReceivedShutdownComplete:
149 socket.Connect();
150 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
151 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
152 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
153 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
154 peer.Shutdown();
155 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN
156 peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN_ACK
157 socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_COMPLETE
158 return;
159 case StartingState::kNumberOfStates:
160 RTC_CHECK(false);
161 return;
162 }
163 }
164
MakeDataChunk(FuzzState & state,SctpPacket::Builder & b)165 void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) {
166 DataChunk::Options options;
167 options.is_unordered = IsUnordered(state.GetByte() != 0);
168 options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
169 options.is_end = Data::IsEnd(state.GetByte() != 0);
170 b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
171 SSN(state.GetByte()), PPID(53), std::vector<uint8_t>(10),
172 options));
173 }
174
MakeInitChunk(FuzzState & state,SctpPacket::Builder & b)175 void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) {
176 Parameters::Builder builder;
177 builder.Add(ForwardTsnSupportedParameter());
178
179 b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
180 TSN(kRandomValue), builder.Build()));
181 }
182
MakeInitAckChunk(FuzzState & state,SctpPacket::Builder & b)183 void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) {
184 Parameters::Builder builder;
185 builder.Add(ForwardTsnSupportedParameter());
186
187 uint8_t state_cookie[] = {1, 2, 3, 4, 5};
188 Parameters::Builder params_builder =
189 Parameters::Builder().Add(StateCookieParameter(state_cookie));
190
191 b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
192 TSN(kRandomValue), builder.Build()));
193 }
194
MakeSackChunk(FuzzState & state,SctpPacket::Builder & b)195 void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) {
196 std::vector<SackChunk::GapAckBlock> gap_ack_blocks;
197 uint16_t last_end = 0;
198 while (gap_ack_blocks.size() < 20) {
199 uint8_t delta_start = state.GetByte();
200 if (delta_start < 0x80) {
201 break;
202 }
203 uint8_t delta_end = state.GetByte();
204
205 uint16_t start = last_end + delta_start;
206 uint16_t end = start + delta_end;
207 last_end = end;
208 gap_ack_blocks.emplace_back(start, end);
209 }
210
211 TSN cum_ack_tsn(kRandomValue + state.GetByte());
212 b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {}));
213 }
214
MakeHeartbeatRequestChunk(FuzzState & state,SctpPacket::Builder & b)215 void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) {
216 uint8_t info[] = {1, 2, 3, 4, 5};
217 b.Add(HeartbeatRequestChunk(
218 Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
219 }
220
MakeHeartbeatAckChunk(FuzzState & state,SctpPacket::Builder & b)221 void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) {
222 std::vector<uint8_t> info(8);
223 b.Add(HeartbeatRequestChunk(
224 Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
225 }
226
MakeAbortChunk(FuzzState & state,SctpPacket::Builder & b)227 void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) {
228 b.Add(AbortChunk(
229 /*filled_in_verification_tag=*/true,
230 Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build()));
231 }
232
MakeErrorChunk(FuzzState & state,SctpPacket::Builder & b)233 void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) {
234 b.Add(ErrorChunk(
235 Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build()));
236 }
237
MakeCookieEchoChunk(FuzzState & state,SctpPacket::Builder & b)238 void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) {
239 std::vector<uint8_t> cookie(StateCookie::kCookieSize);
240 b.Add(CookieEchoChunk(cookie));
241 }
242
MakeCookieAckChunk(FuzzState & state,SctpPacket::Builder & b)243 void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) {
244 b.Add(CookieAckChunk());
245 }
246
MakeShutdownChunk(FuzzState & state,SctpPacket::Builder & b)247 void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) {
248 b.Add(ShutdownChunk(state.GetNextTSN()));
249 }
250
MakeShutdownAckChunk(FuzzState & state,SctpPacket::Builder & b)251 void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) {
252 b.Add(ShutdownAckChunk());
253 }
254
MakeShutdownCompleteChunk(FuzzState & state,SctpPacket::Builder & b)255 void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) {
256 b.Add(ShutdownCompleteChunk(false));
257 }
258
MakeReConfigChunk(FuzzState & state,SctpPacket::Builder & b)259 void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) {
260 std::vector<StreamID> streams = {StreamID(state.GetByte())};
261 Parameters::Builder params_builder =
262 Parameters::Builder().Add(OutgoingSSNResetRequestParameter(
263 ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue),
264 state.GetNextTSN(), streams));
265 b.Add(ReConfigChunk(params_builder.Build()));
266 }
267
MakeForwardTsnChunk(FuzzState & state,SctpPacket::Builder & b)268 void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
269 std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
270 for (;;) {
271 uint8_t stream = state.GetByte();
272 if (skipped_streams.size() > 20 || stream < 0x80) {
273 break;
274 }
275 skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
276 }
277 b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
278 }
279
MakeIDataChunk(FuzzState & state,SctpPacket::Builder & b)280 void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) {
281 DataChunk::Options options;
282 options.is_unordered = IsUnordered(state.GetByte() != 0);
283 options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
284 options.is_end = Data::IsEnd(state.GetByte() != 0);
285 b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
286 state.GetNextMID(), PPID(53), FSN(0),
287 std::vector<uint8_t>(10), options));
288 }
289
MakeIForwardTsnChunk(FuzzState & state,SctpPacket::Builder & b)290 void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
291 std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
292 for (;;) {
293 uint8_t stream = state.GetByte();
294 if (skipped_streams.size() > 20 || stream < 0x80) {
295 break;
296 }
297 skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
298 }
299 b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
300 }
301
302 class RandomFuzzedChunk : public Chunk {
303 public:
RandomFuzzedChunk(FuzzState & state)304 explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {}
305
SerializeTo(std::vector<uint8_t> & out) const306 void SerializeTo(std::vector<uint8_t>& out) const override {
307 size_t bytes = state_.GetByte();
308 for (size_t i = 0; i < bytes; ++i) {
309 out.push_back(state_.GetByte());
310 }
311 }
312
ToString() const313 std::string ToString() const override { return std::string("RANDOM_FUZZED"); }
314
315 private:
316 FuzzState& state_;
317 };
318
MakeChunkWithRandomContent(FuzzState & state,SctpPacket::Builder & b)319 void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) {
320 b.Add(RandomFuzzedChunk(state));
321 }
322
GeneratePacket(FuzzState & state)323 std::vector<uint8_t> GeneratePacket(FuzzState& state) {
324 DcSctpOptions options;
325 // Setting a fixed limit to not be dependent on the defaults, which may
326 // change.
327 options.mtu = 2048;
328 SctpPacket::Builder builder(VerificationTag(kRandomValue), options);
329
330 // The largest expected serialized chunk, as created by fuzzers.
331 static constexpr size_t kMaxChunkSize = 256;
332
333 for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) {
334 switch (state.GetByte()) {
335 case 1:
336 MakeDataChunk(state, builder);
337 break;
338 case 2:
339 MakeInitChunk(state, builder);
340 break;
341 case 3:
342 MakeInitAckChunk(state, builder);
343 break;
344 case 4:
345 MakeSackChunk(state, builder);
346 break;
347 case 5:
348 MakeHeartbeatRequestChunk(state, builder);
349 break;
350 case 6:
351 MakeHeartbeatAckChunk(state, builder);
352 break;
353 case 7:
354 MakeAbortChunk(state, builder);
355 break;
356 case 8:
357 MakeErrorChunk(state, builder);
358 break;
359 case 9:
360 MakeCookieEchoChunk(state, builder);
361 break;
362 case 10:
363 MakeCookieAckChunk(state, builder);
364 break;
365 case 11:
366 MakeShutdownChunk(state, builder);
367 break;
368 case 12:
369 MakeShutdownAckChunk(state, builder);
370 break;
371 case 13:
372 MakeShutdownCompleteChunk(state, builder);
373 break;
374 case 14:
375 MakeReConfigChunk(state, builder);
376 break;
377 case 15:
378 MakeForwardTsnChunk(state, builder);
379 break;
380 case 16:
381 MakeIDataChunk(state, builder);
382 break;
383 case 17:
384 MakeIForwardTsnChunk(state, builder);
385 break;
386 case 18:
387 MakeChunkWithRandomContent(state, builder);
388 break;
389 default:
390 break;
391 }
392 }
393 std::vector<uint8_t> packet = builder.Build();
394 return packet;
395 }
396 } // namespace
397
FuzzSocket(DcSctpSocketInterface & socket,FuzzerCallbacks & cb,rtc::ArrayView<const uint8_t> data)398 void FuzzSocket(DcSctpSocketInterface& socket,
399 FuzzerCallbacks& cb,
400 rtc::ArrayView<const uint8_t> data) {
401 if (data.size() < kMinInputLength || data.size() > kMaxInputLength) {
402 return;
403 }
404 if (data[0] >= static_cast<int>(StartingState::kNumberOfStates)) {
405 return;
406 }
407
408 // Set the socket in a specified valid starting state
409 SetSocketState(socket, cb, static_cast<StartingState>(data[0]));
410
411 FuzzState state(data.subview(1));
412
413 while (!state.empty()) {
414 switch (state.GetByte()) {
415 case 1:
416 // Generate a valid SCTP packet (based on fuzz data) and "receive it".
417 socket.ReceivePacket(GeneratePacket(state));
418 break;
419 case 2:
420 socket.Connect();
421 break;
422 case 3:
423 socket.Shutdown();
424 break;
425 case 4:
426 socket.Close();
427 break;
428 case 5: {
429 StreamID streams[] = {StreamID(state.GetByte())};
430 socket.ResetStreams(streams);
431 } break;
432 case 6: {
433 uint8_t flags = state.GetByte();
434 SendOptions options;
435 options.unordered = IsUnordered(flags & 0x01);
436 options.max_retransmissions =
437 (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt;
438 options.lifecycle_id = LifecycleId(42);
439 size_t payload_exponent = (flags >> 2) % 16;
440 size_t payload_size = static_cast<size_t>(1) << payload_exponent;
441 socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53),
442 std::vector<uint8_t>(payload_size)),
443 options);
444 break;
445 }
446 case 7: {
447 // Expire an active timeout/timer.
448 uint8_t timeout_idx = state.GetByte();
449 absl::optional<TimeoutID> timeout_id = cb.ExpireTimeout(timeout_idx);
450 if (timeout_id.has_value()) {
451 socket.HandleTimeout(*timeout_id);
452 }
453 break;
454 }
455 default:
456 break;
457 }
458 }
459 }
460 } // namespace dcsctp_fuzzers
461 } // namespace dcsctp
462