// Copyright 2012 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_frame_parser.h" #include #include #include #include #include "base/check.h" #include "base/check_op.h" #include "base/logging.h" #include "base/numerics/byte_conversions.h" #include "net/websockets/websocket_frame.h" namespace { constexpr uint8_t kFinalBit = 0x80; constexpr uint8_t kReserved1Bit = 0x40; constexpr uint8_t kReserved2Bit = 0x20; constexpr uint8_t kReserved3Bit = 0x10; constexpr uint8_t kOpCodeMask = 0xF; constexpr uint8_t kMaskBit = 0x80; constexpr uint8_t kPayloadLengthMask = 0x7F; constexpr uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125; constexpr uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126; constexpr uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127; constexpr size_t kMaximumFrameHeaderSize = net::WebSocketFrameHeader::kBaseHeaderSize + net::WebSocketFrameHeader::kMaximumExtendedLengthSize + net::WebSocketFrameHeader::kMaskingKeyLength; } // namespace. namespace net { WebSocketFrameParser::WebSocketFrameParser() = default; WebSocketFrameParser::~WebSocketFrameParser() = default; bool WebSocketFrameParser::Decode( const char* data, size_t length, std::vector>* frame_chunks) { if (websocket_error_ != kWebSocketNormalClosure) return false; if (!length) return true; // TODO(crbug.com/40284755): This span construction can't be sound, the Decode // method should be receiving a span, not a pointer and length. auto data_span = UNSAFE_BUFFERS(base::span(data, length)); // If we have incomplete frame header, try to decode a header combining with // |data|. bool first_chunk = false; if (incomplete_header_buffer_.size() > 0) { DCHECK(!current_frame_header_.get()); const size_t original_size = incomplete_header_buffer_.size(); DCHECK_LE(original_size, kMaximumFrameHeaderSize); incomplete_header_buffer_.insert( incomplete_header_buffer_.end(), data, data + std::min(length, kMaximumFrameHeaderSize - original_size)); const size_t consumed = DecodeFrameHeader(base::as_byte_span(incomplete_header_buffer_)); if (websocket_error_ != kWebSocketNormalClosure) return false; if (!current_frame_header_.get()) return true; DCHECK_GE(consumed, original_size); data_span = data_span.subspan(consumed - original_size); incomplete_header_buffer_.clear(); first_chunk = true; } DCHECK(incomplete_header_buffer_.empty()); while (data_span.size() > 0 || first_chunk) { if (!current_frame_header_.get()) { const size_t consumed = DecodeFrameHeader(base::as_bytes(data_span)); if (websocket_error_ != kWebSocketNormalClosure) return false; // If frame header is incomplete, then carry over the remaining // data to the next round of Decode(). if (!current_frame_header_.get()) { DCHECK(!consumed); incomplete_header_buffer_.insert(incomplete_header_buffer_.end(), data_span.data(), data_span.data() + data_span.size()); // Sanity check: the size of carried-over data should not exceed // the maximum possible length of a frame header. DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize); return true; } DCHECK_GE(data_span.size(), consumed); data_span = data_span.subspan(consumed); first_chunk = true; } DCHECK(incomplete_header_buffer_.empty()); std::unique_ptr frame_chunk = DecodeFramePayload(first_chunk, &data_span); first_chunk = false; DCHECK(frame_chunk.get()); frame_chunks->push_back(std::move(frame_chunk)); } return true; } size_t WebSocketFrameParser::DecodeFrameHeader(base::span data) { DVLOG(3) << "DecodeFrameHeader buffer size:" << ", data size:" << data.size(); typedef WebSocketFrameHeader::OpCode OpCode; DCHECK(!current_frame_header_.get()); // Header needs 2 bytes at minimum. if (data.size() < 2) return 0; size_t current = 0; const uint8_t first_byte = data[current++]; const uint8_t second_byte = data[current++]; const bool final = (first_byte & kFinalBit) != 0; const bool reserved1 = (first_byte & kReserved1Bit) != 0; const bool reserved2 = (first_byte & kReserved2Bit) != 0; const bool reserved3 = (first_byte & kReserved3Bit) != 0; const OpCode opcode = first_byte & kOpCodeMask; uint64_t payload_length = second_byte & kPayloadLengthMask; if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) { if (data.size() < current + 2) return 0; uint16_t payload_length_16 = base::U16FromBigEndian(data.subspan(current).first<2>()); current += 2; payload_length = payload_length_16; if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) { websocket_error_ = kWebSocketErrorProtocolError; return 0; } } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) { if (data.size() < current + 8) return 0; payload_length = base::U64FromBigEndian(data.subspan(current).first<8>()); current += 8; if (payload_length <= UINT16_MAX || payload_length > static_cast(INT64_MAX)) { websocket_error_ = kWebSocketErrorProtocolError; return 0; } if (payload_length > static_cast(INT32_MAX)) { websocket_error_ = kWebSocketErrorMessageTooBig; return 0; } } DCHECK_EQ(websocket_error_, kWebSocketNormalClosure); WebSocketMaskingKey masking_key = {}; const bool masked = (second_byte & kMaskBit) != 0; static constexpr int kMaskingKeyLength = WebSocketFrameHeader::kMaskingKeyLength; if (masked) { if (data.size() < current + kMaskingKeyLength) return 0; std::copy(&data[current], &data[current] + kMaskingKeyLength, masking_key.key); current += kMaskingKeyLength; } current_frame_header_ = std::make_unique(opcode); current_frame_header_->final = final; current_frame_header_->reserved1 = reserved1; current_frame_header_->reserved2 = reserved2; current_frame_header_->reserved3 = reserved3; current_frame_header_->masked = masked; current_frame_header_->masking_key = masking_key; current_frame_header_->payload_length = payload_length; DCHECK_EQ(0u, frame_offset_); return current; } std::unique_ptr WebSocketFrameParser::DecodeFramePayload( bool first_chunk, base::span* data) { // The cast here is safe because |payload_length| is already checked to be // less than std::numeric_limits::max() when the header is parsed. const int chunk_data_size = static_cast( std::min(static_cast(data->size()), current_frame_header_->payload_length - frame_offset_)); auto frame_chunk = std::make_unique(); if (first_chunk) { frame_chunk->header = current_frame_header_->Clone(); } frame_chunk->final_chunk = false; if (chunk_data_size > 0) { frame_chunk->payload = data->subspan(0, chunk_data_size); *data = data->subspan(chunk_data_size); frame_offset_ += chunk_data_size; } DCHECK_LE(frame_offset_, current_frame_header_->payload_length); if (frame_offset_ == current_frame_header_->payload_length) { frame_chunk->final_chunk = true; current_frame_header_.reset(); frame_offset_ = 0; } return frame_chunk; } } // namespace net