xref: /aosp_15_r20/external/cronet/net/websockets/websocket_frame_parser.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_frame_parser.h"
6 
7 #include <algorithm>
8 #include <ostream>
9 #include <utility>
10 #include <vector>
11 
12 #include "base/check.h"
13 #include "base/check_op.h"
14 #include "base/logging.h"
15 #include "base/numerics/byte_conversions.h"
16 #include "net/websockets/websocket_frame.h"
17 
18 namespace {
19 
20 constexpr uint8_t kFinalBit = 0x80;
21 constexpr uint8_t kReserved1Bit = 0x40;
22 constexpr uint8_t kReserved2Bit = 0x20;
23 constexpr uint8_t kReserved3Bit = 0x10;
24 constexpr uint8_t kOpCodeMask = 0xF;
25 constexpr uint8_t kMaskBit = 0x80;
26 constexpr uint8_t kPayloadLengthMask = 0x7F;
27 constexpr uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125;
28 constexpr uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126;
29 constexpr uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127;
30 constexpr size_t kMaximumFrameHeaderSize =
31     net::WebSocketFrameHeader::kBaseHeaderSize +
32     net::WebSocketFrameHeader::kMaximumExtendedLengthSize +
33     net::WebSocketFrameHeader::kMaskingKeyLength;
34 
35 }  // namespace.
36 
37 namespace net {
38 
// Construction and destruction need no custom logic, so both are defaulted.
WebSocketFrameParser::WebSocketFrameParser() = default;

WebSocketFrameParser::~WebSocketFrameParser() = default;
42 
Decode(const char * data,size_t length,std::vector<std::unique_ptr<WebSocketFrameChunk>> * frame_chunks)43 bool WebSocketFrameParser::Decode(
44     const char* data,
45     size_t length,
46     std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks) {
47   if (websocket_error_ != kWebSocketNormalClosure)
48     return false;
49   if (!length)
50     return true;
51 
52   // TODO(crbug.com/40284755): This span construction can't be sound, the Decode
53   // method should be receiving a span, not a pointer and length.
54   auto data_span = UNSAFE_BUFFERS(base::span(data, length));
55   // If we have incomplete frame header, try to decode a header combining with
56   // |data|.
57   bool first_chunk = false;
58   if (incomplete_header_buffer_.size() > 0) {
59     DCHECK(!current_frame_header_.get());
60     const size_t original_size = incomplete_header_buffer_.size();
61     DCHECK_LE(original_size, kMaximumFrameHeaderSize);
62     incomplete_header_buffer_.insert(
63         incomplete_header_buffer_.end(), data,
64         data + std::min(length, kMaximumFrameHeaderSize - original_size));
65     const size_t consumed =
66         DecodeFrameHeader(base::as_byte_span(incomplete_header_buffer_));
67     if (websocket_error_ != kWebSocketNormalClosure)
68       return false;
69     if (!current_frame_header_.get())
70       return true;
71 
72     DCHECK_GE(consumed, original_size);
73     data_span = data_span.subspan(consumed - original_size);
74     incomplete_header_buffer_.clear();
75     first_chunk = true;
76   }
77 
78   DCHECK(incomplete_header_buffer_.empty());
79   while (data_span.size() > 0 || first_chunk) {
80     if (!current_frame_header_.get()) {
81       const size_t consumed = DecodeFrameHeader(base::as_bytes(data_span));
82       if (websocket_error_ != kWebSocketNormalClosure)
83         return false;
84       // If frame header is incomplete, then carry over the remaining
85       // data to the next round of Decode().
86       if (!current_frame_header_.get()) {
87         DCHECK(!consumed);
88         incomplete_header_buffer_.insert(incomplete_header_buffer_.end(),
89                                          data_span.data(),
90                                          data_span.data() + data_span.size());
91         // Sanity check: the size of carried-over data should not exceed
92         // the maximum possible length of a frame header.
93         DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize);
94         return true;
95       }
96       DCHECK_GE(data_span.size(), consumed);
97       data_span = data_span.subspan(consumed);
98       first_chunk = true;
99     }
100     DCHECK(incomplete_header_buffer_.empty());
101     std::unique_ptr<WebSocketFrameChunk> frame_chunk =
102         DecodeFramePayload(first_chunk, &data_span);
103     first_chunk = false;
104     DCHECK(frame_chunk.get());
105     frame_chunks->push_back(std::move(frame_chunk));
106   }
107   return true;
108 }
109 
DecodeFrameHeader(base::span<const uint8_t> data)110 size_t WebSocketFrameParser::DecodeFrameHeader(base::span<const uint8_t> data) {
111   DVLOG(3) << "DecodeFrameHeader buffer size:"
112            << ", data size:" << data.size();
113   typedef WebSocketFrameHeader::OpCode OpCode;
114   DCHECK(!current_frame_header_.get());
115 
116   // Header needs 2 bytes at minimum.
117   if (data.size() < 2)
118     return 0;
119   size_t current = 0;
120   const uint8_t first_byte = data[current++];
121   const uint8_t second_byte = data[current++];
122 
123   const bool final = (first_byte & kFinalBit) != 0;
124   const bool reserved1 = (first_byte & kReserved1Bit) != 0;
125   const bool reserved2 = (first_byte & kReserved2Bit) != 0;
126   const bool reserved3 = (first_byte & kReserved3Bit) != 0;
127   const OpCode opcode = first_byte & kOpCodeMask;
128 
129   uint64_t payload_length = second_byte & kPayloadLengthMask;
130   if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) {
131     if (data.size() < current + 2)
132       return 0;
133     uint16_t payload_length_16 =
134         base::U16FromBigEndian(data.subspan(current).first<2>());
135     current += 2;
136     payload_length = payload_length_16;
137     if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) {
138       websocket_error_ = kWebSocketErrorProtocolError;
139       return 0;
140     }
141   } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) {
142     if (data.size() < current + 8)
143       return 0;
144     payload_length = base::U64FromBigEndian(data.subspan(current).first<8>());
145     current += 8;
146     if (payload_length <= UINT16_MAX ||
147         payload_length > static_cast<uint64_t>(INT64_MAX)) {
148       websocket_error_ = kWebSocketErrorProtocolError;
149       return 0;
150     }
151     if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
152       websocket_error_ = kWebSocketErrorMessageTooBig;
153       return 0;
154     }
155   }
156   DCHECK_EQ(websocket_error_, kWebSocketNormalClosure);
157 
158   WebSocketMaskingKey masking_key = {};
159   const bool masked = (second_byte & kMaskBit) != 0;
160   static constexpr int kMaskingKeyLength =
161       WebSocketFrameHeader::kMaskingKeyLength;
162   if (masked) {
163     if (data.size() < current + kMaskingKeyLength)
164       return 0;
165     std::copy(&data[current], &data[current] + kMaskingKeyLength,
166               masking_key.key);
167     current += kMaskingKeyLength;
168   }
169 
170   current_frame_header_ = std::make_unique<WebSocketFrameHeader>(opcode);
171   current_frame_header_->final = final;
172   current_frame_header_->reserved1 = reserved1;
173   current_frame_header_->reserved2 = reserved2;
174   current_frame_header_->reserved3 = reserved3;
175   current_frame_header_->masked = masked;
176   current_frame_header_->masking_key = masking_key;
177   current_frame_header_->payload_length = payload_length;
178   DCHECK_EQ(0u, frame_offset_);
179   return current;
180 }
181 
DecodeFramePayload(bool first_chunk,base::span<const char> * data)182 std::unique_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload(
183     bool first_chunk,
184     base::span<const char>* data) {
185   // The cast here is safe because |payload_length| is already checked to be
186   // less than std::numeric_limits<int>::max() when the header is parsed.
187   const int chunk_data_size = static_cast<int>(
188       std::min(static_cast<uint64_t>(data->size()),
189                current_frame_header_->payload_length - frame_offset_));
190 
191   auto frame_chunk = std::make_unique<WebSocketFrameChunk>();
192   if (first_chunk) {
193     frame_chunk->header = current_frame_header_->Clone();
194   }
195   frame_chunk->final_chunk = false;
196   if (chunk_data_size > 0) {
197     frame_chunk->payload = data->subspan(0, chunk_data_size);
198     *data = data->subspan(chunk_data_size);
199     frame_offset_ += chunk_data_size;
200   }
201 
202   DCHECK_LE(frame_offset_, current_frame_header_->payload_length);
203   if (frame_offset_ == current_frame_header_->payload_length) {
204     frame_chunk->final_chunk = true;
205     current_frame_header_.reset();
206     frame_offset_ = 0;
207   }
208 
209   return frame_chunk;
210 }
211 
212 }  // namespace net
213