xref: /aosp_15_r20/external/webrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"

#include <string>
#include <utility>
#include <vector>

#include "net/dcsctp/common/math.h"
#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
#include "net/dcsctp/packet/chunk/data_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h"
#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/socket/dcsctp_socket.h"
#include "net/dcsctp/socket/state_cookie.h"
#include "rtc_base/logging.h"
33 
34 namespace dcsctp {
35 namespace dcsctp_fuzzers {
36 namespace {
37 static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue;
38 static constexpr size_t kMinInputLength = 5;
39 static constexpr size_t kMaxInputLength = 1024;
40 
// The state the fuzzed socket is driven into before the fuzz input is
// replayed. The first input byte selects one of these values.
enum class StartingState : int {
  kConnectNotCalled = 0,
  // Reached while the fuzzed socket initiates the association.
  kConnectCalled = 1,
  kReceivedInitAck = 2,
  kReceivedCookieAck = 3,
  // Reached while the fuzzed socket initiates the shutdown.
  kShutdownCalled = 4,
  kReceivedShutdownAck = 5,
  // Reached while the peer initiates the association.
  kReceivedInit = 6,
  kReceivedCookieEcho = 7,
  // Reached while the peer initiates the shutdown.
  kReceivedShutdown = 8,
  kReceivedShutdownComplete = 9,
  // Count of the states above; not a valid starting state itself.
  kNumberOfStates = 10,
};
59 
60 // State about the current fuzzing iteration
61 class FuzzState {
62  public:
FuzzState(rtc::ArrayView<const uint8_t> data)63   explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {}
64 
GetByte()65   uint8_t GetByte() {
66     uint8_t value = 0;
67     if (offset_ < data_.size()) {
68       value = data_[offset_];
69       ++offset_;
70     }
71     return value;
72   }
73 
GetNextTSN()74   TSN GetNextTSN() { return TSN(tsn_++); }
GetNextMID()75   MID GetNextMID() { return MID(mid_++); }
76 
empty() const77   bool empty() const { return offset_ >= data_.size(); }
78 
79  private:
80   uint32_t tsn_ = kRandomValue;
81   uint32_t mid_ = 0;
82   rtc::ArrayView<const uint8_t> data_;
83   size_t offset_ = 0;
84 };
85 
SetSocketState(DcSctpSocketInterface & socket,FuzzerCallbacks & socket_cb,StartingState state)86 void SetSocketState(DcSctpSocketInterface& socket,
87                     FuzzerCallbacks& socket_cb,
88                     StartingState state) {
89   // We'll use another temporary peer socket for the establishment.
90   FuzzerCallbacks peer_cb;
91   DcSctpSocket peer("peer", peer_cb, nullptr, {});
92 
93   switch (state) {
94     case StartingState::kConnectNotCalled:
95       return;
96     case StartingState::kConnectCalled:
97       socket.Connect();
98       return;
99     case StartingState::kReceivedInitAck:
100       socket.Connect();
101       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
102       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
103       return;
104     case StartingState::kReceivedCookieAck:
105       socket.Connect();
106       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
107       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
108       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // COOKIE_ECHO
109       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ACK
110       return;
111     case StartingState::kShutdownCalled:
112       socket.Connect();
113       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
114       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
115       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // COOKIE_ECHO
116       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ACK
117       socket.Shutdown();
118       return;
119     case StartingState::kReceivedShutdownAck:
120       socket.Connect();
121       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
122       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
123       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // COOKIE_ECHO
124       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ACK
125       socket.Shutdown();
126       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // SHUTDOWN
127       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // SHUTDOWN_ACK
128       return;
129     case StartingState::kReceivedInit:
130       peer.Connect();
131       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT
132       return;
133     case StartingState::kReceivedCookieEcho:
134       peer.Connect();
135       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT
136       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT_ACK
137       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ECHO
138       return;
139     case StartingState::kReceivedShutdown:
140       socket.Connect();
141       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
142       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
143       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // COOKIE_ECHO
144       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ACK
145       peer.Shutdown();
146       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // SHUTDOWN
147       return;
148     case StartingState::kReceivedShutdownComplete:
149       socket.Connect();
150       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // INIT
151       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // INIT_ACK
152       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // COOKIE_ECHO
153       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // COOKIE_ACK
154       peer.Shutdown();
155       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // SHUTDOWN
156       peer.ReceivePacket(socket_cb.ConsumeSentPacket());  // SHUTDOWN_ACK
157       socket.ReceivePacket(peer_cb.ConsumeSentPacket());  // SHUTDOWN_COMPLETE
158       return;
159     case StartingState::kNumberOfStates:
160       RTC_CHECK(false);
161       return;
162   }
163 }
164 
MakeDataChunk(FuzzState & state,SctpPacket::Builder & b)165 void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) {
166   DataChunk::Options options;
167   options.is_unordered = IsUnordered(state.GetByte() != 0);
168   options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
169   options.is_end = Data::IsEnd(state.GetByte() != 0);
170   b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
171                   SSN(state.GetByte()), PPID(53), std::vector<uint8_t>(10),
172                   options));
173 }
174 
MakeInitChunk(FuzzState & state,SctpPacket::Builder & b)175 void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) {
176   Parameters::Builder builder;
177   builder.Add(ForwardTsnSupportedParameter());
178 
179   b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
180                   TSN(kRandomValue), builder.Build()));
181 }
182 
MakeInitAckChunk(FuzzState & state,SctpPacket::Builder & b)183 void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) {
184   Parameters::Builder builder;
185   builder.Add(ForwardTsnSupportedParameter());
186 
187   uint8_t state_cookie[] = {1, 2, 3, 4, 5};
188   Parameters::Builder params_builder =
189       Parameters::Builder().Add(StateCookieParameter(state_cookie));
190 
191   b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
192                      TSN(kRandomValue), builder.Build()));
193 }
194 
MakeSackChunk(FuzzState & state,SctpPacket::Builder & b)195 void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) {
196   std::vector<SackChunk::GapAckBlock> gap_ack_blocks;
197   uint16_t last_end = 0;
198   while (gap_ack_blocks.size() < 20) {
199     uint8_t delta_start = state.GetByte();
200     if (delta_start < 0x80) {
201       break;
202     }
203     uint8_t delta_end = state.GetByte();
204 
205     uint16_t start = last_end + delta_start;
206     uint16_t end = start + delta_end;
207     last_end = end;
208     gap_ack_blocks.emplace_back(start, end);
209   }
210 
211   TSN cum_ack_tsn(kRandomValue + state.GetByte());
212   b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {}));
213 }
214 
MakeHeartbeatRequestChunk(FuzzState & state,SctpPacket::Builder & b)215 void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) {
216   uint8_t info[] = {1, 2, 3, 4, 5};
217   b.Add(HeartbeatRequestChunk(
218       Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
219 }
220 
MakeHeartbeatAckChunk(FuzzState & state,SctpPacket::Builder & b)221 void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) {
222   std::vector<uint8_t> info(8);
223   b.Add(HeartbeatRequestChunk(
224       Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
225 }
226 
MakeAbortChunk(FuzzState & state,SctpPacket::Builder & b)227 void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) {
228   b.Add(AbortChunk(
229       /*filled_in_verification_tag=*/true,
230       Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build()));
231 }
232 
MakeErrorChunk(FuzzState & state,SctpPacket::Builder & b)233 void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) {
234   b.Add(ErrorChunk(
235       Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build()));
236 }
237 
MakeCookieEchoChunk(FuzzState & state,SctpPacket::Builder & b)238 void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) {
239   std::vector<uint8_t> cookie(StateCookie::kCookieSize);
240   b.Add(CookieEchoChunk(cookie));
241 }
242 
MakeCookieAckChunk(FuzzState & state,SctpPacket::Builder & b)243 void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) {
244   b.Add(CookieAckChunk());
245 }
246 
MakeShutdownChunk(FuzzState & state,SctpPacket::Builder & b)247 void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) {
248   b.Add(ShutdownChunk(state.GetNextTSN()));
249 }
250 
MakeShutdownAckChunk(FuzzState & state,SctpPacket::Builder & b)251 void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) {
252   b.Add(ShutdownAckChunk());
253 }
254 
MakeShutdownCompleteChunk(FuzzState & state,SctpPacket::Builder & b)255 void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) {
256   b.Add(ShutdownCompleteChunk(false));
257 }
258 
MakeReConfigChunk(FuzzState & state,SctpPacket::Builder & b)259 void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) {
260   std::vector<StreamID> streams = {StreamID(state.GetByte())};
261   Parameters::Builder params_builder =
262       Parameters::Builder().Add(OutgoingSSNResetRequestParameter(
263           ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue),
264           state.GetNextTSN(), streams));
265   b.Add(ReConfigChunk(params_builder.Build()));
266 }
267 
MakeForwardTsnChunk(FuzzState & state,SctpPacket::Builder & b)268 void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
269   std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
270   for (;;) {
271     uint8_t stream = state.GetByte();
272     if (skipped_streams.size() > 20 || stream < 0x80) {
273       break;
274     }
275     skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
276   }
277   b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
278 }
279 
MakeIDataChunk(FuzzState & state,SctpPacket::Builder & b)280 void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) {
281   DataChunk::Options options;
282   options.is_unordered = IsUnordered(state.GetByte() != 0);
283   options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
284   options.is_end = Data::IsEnd(state.GetByte() != 0);
285   b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
286                    state.GetNextMID(), PPID(53), FSN(0),
287                    std::vector<uint8_t>(10), options));
288 }
289 
MakeIForwardTsnChunk(FuzzState & state,SctpPacket::Builder & b)290 void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
291   std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
292   for (;;) {
293     uint8_t stream = state.GetByte();
294     if (skipped_streams.size() > 20 || stream < 0x80) {
295       break;
296     }
297     skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
298   }
299   b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
300 }
301 
302 class RandomFuzzedChunk : public Chunk {
303  public:
RandomFuzzedChunk(FuzzState & state)304   explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {}
305 
SerializeTo(std::vector<uint8_t> & out) const306   void SerializeTo(std::vector<uint8_t>& out) const override {
307     size_t bytes = state_.GetByte();
308     for (size_t i = 0; i < bytes; ++i) {
309       out.push_back(state_.GetByte());
310     }
311   }
312 
ToString() const313   std::string ToString() const override { return std::string("RANDOM_FUZZED"); }
314 
315  private:
316   FuzzState& state_;
317 };
318 
MakeChunkWithRandomContent(FuzzState & state,SctpPacket::Builder & b)319 void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) {
320   b.Add(RandomFuzzedChunk(state));
321 }
322 
GeneratePacket(FuzzState & state)323 std::vector<uint8_t> GeneratePacket(FuzzState& state) {
324   DcSctpOptions options;
325   // Setting a fixed limit to not be dependent on the defaults, which may
326   // change.
327   options.mtu = 2048;
328   SctpPacket::Builder builder(VerificationTag(kRandomValue), options);
329 
330   // The largest expected serialized chunk, as created by fuzzers.
331   static constexpr size_t kMaxChunkSize = 256;
332 
333   for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) {
334     switch (state.GetByte()) {
335       case 1:
336         MakeDataChunk(state, builder);
337         break;
338       case 2:
339         MakeInitChunk(state, builder);
340         break;
341       case 3:
342         MakeInitAckChunk(state, builder);
343         break;
344       case 4:
345         MakeSackChunk(state, builder);
346         break;
347       case 5:
348         MakeHeartbeatRequestChunk(state, builder);
349         break;
350       case 6:
351         MakeHeartbeatAckChunk(state, builder);
352         break;
353       case 7:
354         MakeAbortChunk(state, builder);
355         break;
356       case 8:
357         MakeErrorChunk(state, builder);
358         break;
359       case 9:
360         MakeCookieEchoChunk(state, builder);
361         break;
362       case 10:
363         MakeCookieAckChunk(state, builder);
364         break;
365       case 11:
366         MakeShutdownChunk(state, builder);
367         break;
368       case 12:
369         MakeShutdownAckChunk(state, builder);
370         break;
371       case 13:
372         MakeShutdownCompleteChunk(state, builder);
373         break;
374       case 14:
375         MakeReConfigChunk(state, builder);
376         break;
377       case 15:
378         MakeForwardTsnChunk(state, builder);
379         break;
380       case 16:
381         MakeIDataChunk(state, builder);
382         break;
383       case 17:
384         MakeIForwardTsnChunk(state, builder);
385         break;
386       case 18:
387         MakeChunkWithRandomContent(state, builder);
388         break;
389       default:
390         break;
391     }
392   }
393   std::vector<uint8_t> packet = builder.Build();
394   return packet;
395 }
396 }  // namespace
397 
FuzzSocket(DcSctpSocketInterface & socket,FuzzerCallbacks & cb,rtc::ArrayView<const uint8_t> data)398 void FuzzSocket(DcSctpSocketInterface& socket,
399                 FuzzerCallbacks& cb,
400                 rtc::ArrayView<const uint8_t> data) {
401   if (data.size() < kMinInputLength || data.size() > kMaxInputLength) {
402     return;
403   }
404   if (data[0] >= static_cast<int>(StartingState::kNumberOfStates)) {
405     return;
406   }
407 
408   // Set the socket in a specified valid starting state
409   SetSocketState(socket, cb, static_cast<StartingState>(data[0]));
410 
411   FuzzState state(data.subview(1));
412 
413   while (!state.empty()) {
414     switch (state.GetByte()) {
415       case 1:
416         // Generate a valid SCTP packet (based on fuzz data) and "receive it".
417         socket.ReceivePacket(GeneratePacket(state));
418         break;
419       case 2:
420         socket.Connect();
421         break;
422       case 3:
423         socket.Shutdown();
424         break;
425       case 4:
426         socket.Close();
427         break;
428       case 5: {
429         StreamID streams[] = {StreamID(state.GetByte())};
430         socket.ResetStreams(streams);
431       } break;
432       case 6: {
433         uint8_t flags = state.GetByte();
434         SendOptions options;
435         options.unordered = IsUnordered(flags & 0x01);
436         options.max_retransmissions =
437             (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt;
438         options.lifecycle_id = LifecycleId(42);
439         size_t payload_exponent = (flags >> 2) % 16;
440         size_t payload_size = static_cast<size_t>(1) << payload_exponent;
441         socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53),
442                                   std::vector<uint8_t>(payload_size)),
443                     options);
444         break;
445       }
446       case 7: {
447         // Expire an active timeout/timer.
448         uint8_t timeout_idx = state.GetByte();
449         absl::optional<TimeoutID> timeout_id = cb.ExpireTimeout(timeout_idx);
450         if (timeout_id.has_value()) {
451           socket.HandleTimeout(*timeout_id);
452         }
453         break;
454       }
455       default:
456         break;
457     }
458   }
459 }
460 }  // namespace dcsctp_fuzzers
461 }  // namespace dcsctp
462