1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_frame_parser.h"
6
7 #include <algorithm>
8 #include <ostream>
9 #include <utility>
10 #include <vector>
11
12 #include "base/check.h"
13 #include "base/check_op.h"
14 #include "base/logging.h"
15 #include "base/numerics/byte_conversions.h"
16 #include "net/websockets/websocket_frame.h"
17
18 namespace {
19
20 constexpr uint8_t kFinalBit = 0x80;
21 constexpr uint8_t kReserved1Bit = 0x40;
22 constexpr uint8_t kReserved2Bit = 0x20;
23 constexpr uint8_t kReserved3Bit = 0x10;
24 constexpr uint8_t kOpCodeMask = 0xF;
25 constexpr uint8_t kMaskBit = 0x80;
26 constexpr uint8_t kPayloadLengthMask = 0x7F;
27 constexpr uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125;
28 constexpr uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126;
29 constexpr uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127;
30 constexpr size_t kMaximumFrameHeaderSize =
31 net::WebSocketFrameHeader::kBaseHeaderSize +
32 net::WebSocketFrameHeader::kMaximumExtendedLengthSize +
33 net::WebSocketFrameHeader::kMaskingKeyLength;
34
35 } // namespace.
36
37 namespace net {
38
// Defaulted constructor. Any initial member values (for example a zero
// |frame_offset_| and no error in |websocket_error_|) come from the class
// declaration, which is not part of this file -- NOTE(review): confirm there.
WebSocketFrameParser::WebSocketFrameParser() = default;

// Defaulted destructor; owned members (presumably including the
// |current_frame_header_| smart pointer) release their resources themselves.
WebSocketFrameParser::~WebSocketFrameParser() = default;
42
Decode(const char * data,size_t length,std::vector<std::unique_ptr<WebSocketFrameChunk>> * frame_chunks)43 bool WebSocketFrameParser::Decode(
44 const char* data,
45 size_t length,
46 std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks) {
47 if (websocket_error_ != kWebSocketNormalClosure)
48 return false;
49 if (!length)
50 return true;
51
52 // TODO(crbug.com/40284755): This span construction can't be sound, the Decode
53 // method should be receiving a span, not a pointer and length.
54 auto data_span = UNSAFE_BUFFERS(base::span(data, length));
55 // If we have incomplete frame header, try to decode a header combining with
56 // |data|.
57 bool first_chunk = false;
58 if (incomplete_header_buffer_.size() > 0) {
59 DCHECK(!current_frame_header_.get());
60 const size_t original_size = incomplete_header_buffer_.size();
61 DCHECK_LE(original_size, kMaximumFrameHeaderSize);
62 incomplete_header_buffer_.insert(
63 incomplete_header_buffer_.end(), data,
64 data + std::min(length, kMaximumFrameHeaderSize - original_size));
65 const size_t consumed =
66 DecodeFrameHeader(base::as_byte_span(incomplete_header_buffer_));
67 if (websocket_error_ != kWebSocketNormalClosure)
68 return false;
69 if (!current_frame_header_.get())
70 return true;
71
72 DCHECK_GE(consumed, original_size);
73 data_span = data_span.subspan(consumed - original_size);
74 incomplete_header_buffer_.clear();
75 first_chunk = true;
76 }
77
78 DCHECK(incomplete_header_buffer_.empty());
79 while (data_span.size() > 0 || first_chunk) {
80 if (!current_frame_header_.get()) {
81 const size_t consumed = DecodeFrameHeader(base::as_bytes(data_span));
82 if (websocket_error_ != kWebSocketNormalClosure)
83 return false;
84 // If frame header is incomplete, then carry over the remaining
85 // data to the next round of Decode().
86 if (!current_frame_header_.get()) {
87 DCHECK(!consumed);
88 incomplete_header_buffer_.insert(incomplete_header_buffer_.end(),
89 data_span.data(),
90 data_span.data() + data_span.size());
91 // Sanity check: the size of carried-over data should not exceed
92 // the maximum possible length of a frame header.
93 DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize);
94 return true;
95 }
96 DCHECK_GE(data_span.size(), consumed);
97 data_span = data_span.subspan(consumed);
98 first_chunk = true;
99 }
100 DCHECK(incomplete_header_buffer_.empty());
101 std::unique_ptr<WebSocketFrameChunk> frame_chunk =
102 DecodeFramePayload(first_chunk, &data_span);
103 first_chunk = false;
104 DCHECK(frame_chunk.get());
105 frame_chunks->push_back(std::move(frame_chunk));
106 }
107 return true;
108 }
109
DecodeFrameHeader(base::span<const uint8_t> data)110 size_t WebSocketFrameParser::DecodeFrameHeader(base::span<const uint8_t> data) {
111 DVLOG(3) << "DecodeFrameHeader buffer size:"
112 << ", data size:" << data.size();
113 typedef WebSocketFrameHeader::OpCode OpCode;
114 DCHECK(!current_frame_header_.get());
115
116 // Header needs 2 bytes at minimum.
117 if (data.size() < 2)
118 return 0;
119 size_t current = 0;
120 const uint8_t first_byte = data[current++];
121 const uint8_t second_byte = data[current++];
122
123 const bool final = (first_byte & kFinalBit) != 0;
124 const bool reserved1 = (first_byte & kReserved1Bit) != 0;
125 const bool reserved2 = (first_byte & kReserved2Bit) != 0;
126 const bool reserved3 = (first_byte & kReserved3Bit) != 0;
127 const OpCode opcode = first_byte & kOpCodeMask;
128
129 uint64_t payload_length = second_byte & kPayloadLengthMask;
130 if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) {
131 if (data.size() < current + 2)
132 return 0;
133 uint16_t payload_length_16 =
134 base::U16FromBigEndian(data.subspan(current).first<2>());
135 current += 2;
136 payload_length = payload_length_16;
137 if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) {
138 websocket_error_ = kWebSocketErrorProtocolError;
139 return 0;
140 }
141 } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) {
142 if (data.size() < current + 8)
143 return 0;
144 payload_length = base::U64FromBigEndian(data.subspan(current).first<8>());
145 current += 8;
146 if (payload_length <= UINT16_MAX ||
147 payload_length > static_cast<uint64_t>(INT64_MAX)) {
148 websocket_error_ = kWebSocketErrorProtocolError;
149 return 0;
150 }
151 if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
152 websocket_error_ = kWebSocketErrorMessageTooBig;
153 return 0;
154 }
155 }
156 DCHECK_EQ(websocket_error_, kWebSocketNormalClosure);
157
158 WebSocketMaskingKey masking_key = {};
159 const bool masked = (second_byte & kMaskBit) != 0;
160 static constexpr int kMaskingKeyLength =
161 WebSocketFrameHeader::kMaskingKeyLength;
162 if (masked) {
163 if (data.size() < current + kMaskingKeyLength)
164 return 0;
165 std::copy(&data[current], &data[current] + kMaskingKeyLength,
166 masking_key.key);
167 current += kMaskingKeyLength;
168 }
169
170 current_frame_header_ = std::make_unique<WebSocketFrameHeader>(opcode);
171 current_frame_header_->final = final;
172 current_frame_header_->reserved1 = reserved1;
173 current_frame_header_->reserved2 = reserved2;
174 current_frame_header_->reserved3 = reserved3;
175 current_frame_header_->masked = masked;
176 current_frame_header_->masking_key = masking_key;
177 current_frame_header_->payload_length = payload_length;
178 DCHECK_EQ(0u, frame_offset_);
179 return current;
180 }
181
DecodeFramePayload(bool first_chunk,base::span<const char> * data)182 std::unique_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload(
183 bool first_chunk,
184 base::span<const char>* data) {
185 // The cast here is safe because |payload_length| is already checked to be
186 // less than std::numeric_limits<int>::max() when the header is parsed.
187 const int chunk_data_size = static_cast<int>(
188 std::min(static_cast<uint64_t>(data->size()),
189 current_frame_header_->payload_length - frame_offset_));
190
191 auto frame_chunk = std::make_unique<WebSocketFrameChunk>();
192 if (first_chunk) {
193 frame_chunk->header = current_frame_header_->Clone();
194 }
195 frame_chunk->final_chunk = false;
196 if (chunk_data_size > 0) {
197 frame_chunk->payload = data->subspan(0, chunk_data_size);
198 *data = data->subspan(chunk_data_size);
199 frame_offset_ += chunk_data_size;
200 }
201
202 DCHECK_LE(frame_offset_, current_frame_header_->payload_length);
203 if (frame_offset_ == current_frame_header_->payload_length) {
204 frame_chunk->final_chunk = true;
205 current_frame_header_.reset();
206 frame_offset_ = 0;
207 }
208
209 return frame_chunk;
210 }
211
212 } // namespace net
213