xref: /aosp_15_r20/external/cronet/net/websockets/websocket_deflate_stream_fuzzer.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2015 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <stddef.h>
6 #include <stdint.h>
7 
8 #include <fuzzer/FuzzedDataProvider.h>
9 
10 #include <string>
11 #include <vector>
12 
13 #include "base/check.h"
14 #include "base/memory/raw_ptr.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/strings/string_number_conversions.h"
17 #include "base/strings/string_piece.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/net_errors.h"
20 #include "net/log/net_log_with_source.h"
21 #include "net/websockets/websocket_deflate_parameters.h"
22 #include "net/websockets/websocket_deflate_predictor.h"
23 #include "net/websockets/websocket_deflate_predictor_impl.h"
24 #include "net/websockets/websocket_deflate_stream.h"
25 #include "net/websockets/websocket_extension.h"
26 #include "net/websockets/websocket_frame.h"
27 #include "net/websockets/websocket_stream.h"
28 
29 namespace net {
30 
31 namespace {
32 
// CreateFrame() needs this many bytes of input to have any choice in what it
// builds. With fewer bytes left it can only ever produce an empty frame, and
// the fuzzer can already produce that same empty frame from exactly this many
// bytes, so shorter tails are not worth exploring.
constexpr size_t MIN_BYTES_TO_CREATE_A_FRAME{3};

// Number of leading input bytes used to choose the deflate parameters.
constexpr size_t BYTES_CONSUMED_BY_PARAMS{2};

// Smallest input worth running. An input of params + one frame's worth of
// bytes tests a single frame; accepting one byte less than that also covers
// the case where no frame is produced at all.
constexpr size_t MIN_USEFUL_SIZE =
    MIN_BYTES_TO_CREATE_A_FRAME + BYTES_CONSUMED_BY_PARAMS - 1;
46 
47 class WebSocketFuzzedStream final : public WebSocketStream {
48  public:
WebSocketFuzzedStream(FuzzedDataProvider * fuzzed_data_provider)49   explicit WebSocketFuzzedStream(FuzzedDataProvider* fuzzed_data_provider)
50       : fuzzed_data_provider_(fuzzed_data_provider) {}
51 
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)52   int ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
53                  CompletionOnceCallback callback) override {
54     if (fuzzed_data_provider_->remaining_bytes() < MIN_BYTES_TO_CREATE_A_FRAME)
55       return ERR_CONNECTION_CLOSED;
56     while (fuzzed_data_provider_->remaining_bytes() > 0)
57       frames->push_back(CreateFrame());
58     return OK;
59   }
60 
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)61   int WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>>* frames,
62                   CompletionOnceCallback callback) override {
63     return ERR_FILE_NOT_FOUND;
64   }
65 
Close()66   void Close() override {}
GetSubProtocol() const67   std::string GetSubProtocol() const override { return std::string(); }
GetExtensions() const68   std::string GetExtensions() const override { return std::string(); }
GetNetLogWithSource() const69   const NetLogWithSource& GetNetLogWithSource() const override {
70     return net_log_;
71   }
72 
73  private:
CreateFrame()74   std::unique_ptr<WebSocketFrame> CreateFrame() {
75     WebSocketFrameHeader::OpCode opcode =
76         fuzzed_data_provider_
77             ->ConsumeIntegralInRange<WebSocketFrameHeader::OpCode>(
78                 WebSocketFrameHeader::kOpCodeContinuation,
79                 WebSocketFrameHeader::kOpCodeControlUnused);
80     auto frame = std::make_unique<WebSocketFrame>(opcode);
81     // Bad news: ConsumeBool actually consumes a whole byte per call, so do
82     // something hacky to conserve precious bits.
83     uint8_t flags = fuzzed_data_provider_->ConsumeIntegral<uint8_t>();
84     frame->header.final = flags & 0x1;
85     frame->header.reserved1 = (flags >> 1) & 0x1;
86     frame->header.reserved2 = (flags >> 2) & 0x1;
87     frame->header.reserved3 = (flags >> 3) & 0x1;
88     frame->header.masked = (flags >> 4) & 0x1;
89     uint64_t payload_length =
90         fuzzed_data_provider_->ConsumeIntegralInRange(0, 64);
91     std::vector<char> payload =
92         fuzzed_data_provider_->ConsumeBytes<char>(payload_length);
93     auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload.size());
94     memcpy(buffer->data(), payload.data(), payload.size());
95     buffers_.push_back(buffer);
96     frame->payload = buffer->data();
97     frame->header.payload_length = payload.size();
98     return frame;
99   }
100 
101   std::vector<scoped_refptr<IOBufferWithSize>> buffers_;
102 
103   raw_ptr<FuzzedDataProvider> fuzzed_data_provider_;
104 
105   NetLogWithSource net_log_;
106 };
107 
WebSocketDeflateStreamFuzz(const uint8_t * data,size_t size)108 void WebSocketDeflateStreamFuzz(const uint8_t* data, size_t size) {
109   FuzzedDataProvider fuzzed_data_provider(data, size);
110   uint8_t flags = fuzzed_data_provider.ConsumeIntegral<uint8_t>();
111   bool server_no_context_takeover = flags & 0x1;
112   bool client_no_context_takeover = (flags >> 1) & 0x1;
113   uint8_t window_bits = fuzzed_data_provider.ConsumeIntegral<uint8_t>();
114   int server_max_window_bits = (window_bits & 0x7) + 8;
115   int client_max_window_bits = ((window_bits >> 3) & 0x7) + 8;
116   // WebSocketDeflateStream needs to be constructed on each call because it
117   // has state.
118   WebSocketExtension params("permessage-deflate");
119   if (server_no_context_takeover)
120     params.Add(WebSocketExtension::Parameter("server_no_context_takeover"));
121   if (client_no_context_takeover)
122     params.Add(WebSocketExtension::Parameter("client_no_context_takeover"));
123   params.Add(WebSocketExtension::Parameter(
124       "server_max_window_bits", base::NumberToString(server_max_window_bits)));
125   params.Add(WebSocketExtension::Parameter(
126       "client_max_window_bits", base::NumberToString(client_max_window_bits)));
127   std::string failure_message;
128   WebSocketDeflateParameters parameters;
129   DCHECK(parameters.Initialize(params, &failure_message)) << failure_message;
130   WebSocketDeflateStream deflate_stream(
131       std::make_unique<WebSocketFuzzedStream>(&fuzzed_data_provider),
132       parameters, std::make_unique<WebSocketDeflatePredictorImpl>());
133   std::vector<std::unique_ptr<net::WebSocketFrame>> frames;
134   deflate_stream.ReadFrames(&frames, CompletionOnceCallback());
135 }
136 
137 }  // namespace
138 
139 }  // namespace net
140 
141 // Entry point for LibFuzzer.
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)142 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
143   if (size < net::MIN_USEFUL_SIZE)
144     return 0;
145   net::WebSocketDeflateStreamFuzz(data, size);
146 
147   return 0;
148 }
149