xref: /aosp_15_r20/external/cronet/net/server/web_socket_encoder.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/server/web_socket_encoder.h"
6 
7 #include <limits>
8 #include <string_view>
9 #include <utility>
10 
11 #include "base/check.h"
12 #include "base/memory/ptr_util.h"
13 #include "base/strings/strcat.h"
14 #include "base/strings/string_number_conversions.h"
15 #include "net/base/io_buffer.h"
16 #include "net/websockets/websocket_deflate_parameters.h"
17 #include "net/websockets/websocket_extension.h"
18 #include "net/websockets/websocket_extension_parser.h"
19 #include "net/websockets/websocket_frame.h"
20 
21 namespace net {
22 
// Out-of-line definition of the static member WebSocketEncoder::kClientExtensions.
// The value names the permessage-deflate extension together with its
// client_max_window_bits parameter.
// NOTE(review): presumably this is the offer a client places in the
// Sec-WebSocket-Extensions request header -- confirm against callers.
23 const char WebSocketEncoder::kClientExtensions[] =
24     "permessage-deflate; client_max_window_bits";
25 
26 namespace {
27 
// Value passed for both size arguments of the WebSocketInflater constructor.
constexpr int kInflaterChunkSize = 16 * 1024;

// Constants for hybi-10 frame format.

// Bits of the first header byte.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Bits of the second header byte.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Values of the 7-bit length field: lengths up to 125 are stored inline, 126
// announces a 2-byte extended length and 127 an 8-byte extended length.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Number of masking-key bytes that follow the length field of a masked frame.
constexpr size_t kMaskingKeyWidthInBytes = 4;
44 
DecodeFrameHybi17(std::string_view frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)45 WebSocket::ParseResult DecodeFrameHybi17(std::string_view frame,
46                                          bool client_frame,
47                                          int* bytes_consumed,
48                                          std::string* output,
49                                          bool* compressed) {
50   size_t data_length = frame.length();
51   if (data_length < 2)
52     return WebSocket::FRAME_INCOMPLETE;
53 
54   const char* buffer_begin = const_cast<char*>(frame.data());
55   const char* p = buffer_begin;
56   const char* buffer_end = p + data_length;
57 
58   unsigned char first_byte = *p++;
59   unsigned char second_byte = *p++;
60 
61   bool final = (first_byte & kFinalBit) != 0;
62   bool reserved1 = (first_byte & kReserved1Bit) != 0;
63   bool reserved2 = (first_byte & kReserved2Bit) != 0;
64   bool reserved3 = (first_byte & kReserved3Bit) != 0;
65   int op_code = first_byte & kOpCodeMask;
66   bool masked = (second_byte & kMaskBit) != 0;
67   *compressed = reserved1;
68   if (reserved2 || reserved3)
69     return WebSocket::FRAME_ERROR;  // Only compression extension is supported.
70 
71   bool closed = false;
72   switch (op_code) {
73     case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
74       closed = true;
75       break;
76 
77     case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
78     case WebSocketFrameHeader::OpCodeEnum::
79         kOpCodeContinuation:  // Treated in the same as kOpCodeText.
80     case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
81     case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
82       break;
83 
84     case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary:  // We don't support
85                                                            // binary frames yet.
86     default:
87       return WebSocket::FRAME_ERROR;
88   }
89 
90   if (client_frame && !masked)  // In Hybi-17 spec client MUST mask its frame.
91     return WebSocket::FRAME_ERROR;
92 
93   uint64_t payload_length64 = second_byte & kPayloadLengthMask;
94   if (payload_length64 > kMaxSingleBytePayloadLength) {
95     int extended_payload_length_size;
96     if (payload_length64 == kTwoBytePayloadLengthField) {
97       extended_payload_length_size = 2;
98     } else {
99       DCHECK(payload_length64 == kEightBytePayloadLengthField);
100       extended_payload_length_size = 8;
101     }
102     if (buffer_end - p < extended_payload_length_size)
103       return WebSocket::FRAME_INCOMPLETE;
104     payload_length64 = 0;
105     for (int i = 0; i < extended_payload_length_size; ++i) {
106       payload_length64 <<= 8;
107       payload_length64 |= static_cast<unsigned char>(*p++);
108     }
109   }
110 
111   size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
112   static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
113   static size_t max_length = std::numeric_limits<size_t>::max();
114   if (payload_length64 > max_payload_length ||
115       payload_length64 + actual_masking_key_length > max_length) {
116     // WebSocket frame length too large.
117     return WebSocket::FRAME_ERROR;
118   }
119   size_t payload_length = static_cast<size_t>(payload_length64);
120 
121   size_t total_length = actual_masking_key_length + payload_length;
122   if (static_cast<size_t>(buffer_end - p) < total_length)
123     return WebSocket::FRAME_INCOMPLETE;
124 
125   if (masked) {
126     output->resize(payload_length);
127     const char* masking_key = p;
128     char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
129     for (size_t i = 0; i < payload_length; ++i)  // Unmask the payload.
130       (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
131   } else {
132     output->assign(p, p + payload_length);
133   }
134 
135   size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
136   *bytes_consumed = pos;
137 
138   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
139     return WebSocket::FRAME_PING;
140 
141   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
142     return WebSocket::FRAME_PONG;
143 
144   if (closed)
145     return WebSocket::FRAME_CLOSE;
146 
147   return final ? WebSocket::FRAME_OK_FINAL : WebSocket::FRAME_OK_MIDDLE;
148 }
149 
EncodeFrameHybi17(std::string_view message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)150 void EncodeFrameHybi17(std::string_view message,
151                        int masking_key,
152                        bool compressed,
153                        WebSocketFrameHeader::OpCodeEnum op_code,
154                        std::string* output) {
155   std::vector<char> frame;
156   size_t data_length = message.length();
157 
158   int reserved1 = compressed ? kReserved1Bit : 0;
159   frame.push_back(kFinalBit | op_code | reserved1);
160   char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
161   if (data_length <= kMaxSingleBytePayloadLength) {
162     frame.push_back(static_cast<char>(data_length) | mask_key_bit);
163   } else if (data_length <= 0xFFFF) {
164     frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
165     frame.push_back((data_length & 0xFF00) >> 8);
166     frame.push_back(data_length & 0xFF);
167   } else {
168     frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
169     char extended_payload_length[8];
170     size_t remaining = data_length;
171     // Fill the length into extended_payload_length in the network byte order.
172     for (int i = 0; i < 8; ++i) {
173       extended_payload_length[7 - i] = remaining & 0xFF;
174       remaining >>= 8;
175     }
176     frame.insert(frame.end(), extended_payload_length,
177                  extended_payload_length + 8);
178     DCHECK(!remaining);
179   }
180 
181   const char* data = const_cast<char*>(message.data());
182   if (masking_key != 0) {
183     const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
184     frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
185     for (size_t i = 0; i < data_length; ++i)  // Mask the payload.
186       frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
187   } else {
188     frame.insert(frame.end(), data, data + data_length);
189   }
190   *output = std::string(frame.data(), frame.size());
191 }
192 
193 }  // anonymous namespace
194 
195 // static
CreateServer()196 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
197   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
198 }
199 
200 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)201 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
202     const std::string& extensions,
203     WebSocketDeflateParameters* deflate_parameters) {
204   WebSocketExtensionParser parser;
205   if (!parser.Parse(extensions)) {
206     // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
207     // connection.
208     return nullptr;
209   }
210 
211   for (const auto& extension : parser.extensions()) {
212     std::string failure_message;
213     WebSocketDeflateParameters offer;
214     if (!offer.Initialize(extension, &failure_message) ||
215         !offer.IsValidAsRequest(&failure_message)) {
216       // We decline unknown / malformed extensions.
217       continue;
218     }
219 
220     WebSocketDeflateParameters response = offer;
221     if (offer.is_client_max_window_bits_specified() &&
222         !offer.has_client_max_window_bits_value()) {
223       // We need to choose one value for the response.
224       response.SetClientMaxWindowBits(15);
225     }
226     DCHECK(response.IsValidAsResponse());
227     DCHECK(offer.IsCompatibleWith(response));
228     auto deflater = std::make_unique<WebSocketDeflater>(
229         response.server_context_take_over_mode());
230     auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
231                                                         kInflaterChunkSize);
232     if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
233         !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
234       // For some reason we cannot accept the parameters.
235       continue;
236     }
237     *deflate_parameters = response;
238     return base::WrapUnique(new WebSocketEncoder(
239         FOR_SERVER, std::move(deflater), std::move(inflater)));
240   }
241 
242   // We cannot find an acceptable offer.
243   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
244 }
245 
246 // static
CreateClient(const std::string & response_extensions)247 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
248     const std::string& response_extensions) {
249   // TODO(yhirano): Add a way to return an error.
250 
251   WebSocketExtensionParser parser;
252   if (!parser.Parse(response_extensions)) {
253     // Parse error. Note that there are two cases here.
254     // 1) There is no Sec-WebSocket-Extensions header.
255     // 2) There is a malformed Sec-WebSocketExtensions header.
256     // We should return a deflate-disabled encoder for the former case and
257     // fail the connection for the latter case.
258     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
259   }
260   if (parser.extensions().size() != 1) {
261     // Only permessage-deflate extension is supported.
262     // TODO (yhirano): Fail the connection.
263     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
264   }
265   const auto& extension = parser.extensions()[0];
266   WebSocketDeflateParameters params;
267   std::string failure_message;
268   if (!params.Initialize(extension, &failure_message) ||
269       !params.IsValidAsResponse(&failure_message)) {
270     // TODO (yhirano): Fail the connection.
271     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
272   }
273 
274   auto deflater = std::make_unique<WebSocketDeflater>(
275       params.client_context_take_over_mode());
276   auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
277                                                       kInflaterChunkSize);
278   if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
279       !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
280     // TODO (yhirano): Fail the connection.
281     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
282   }
283 
284   return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
285                                                std::move(inflater)));
286 }
287 
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)288 WebSocketEncoder::WebSocketEncoder(Type type,
289                                    std::unique_ptr<WebSocketDeflater> deflater,
290                                    std::unique_ptr<WebSocketInflater> inflater)
291     : type_(type),
292       deflater_(std::move(deflater)),
293       inflater_(std::move(inflater)) {}
294 
295 WebSocketEncoder::~WebSocketEncoder() = default;
296 
DecodeFrame(std::string_view frame,int * bytes_consumed,std::string * output)297 WebSocket::ParseResult WebSocketEncoder::DecodeFrame(std::string_view frame,
298                                                      int* bytes_consumed,
299                                                      std::string* output) {
300   bool compressed;
301   std::string current_output;
302   WebSocket::ParseResult result = DecodeFrameHybi17(
303       frame, type_ == FOR_SERVER, bytes_consumed, &current_output, &compressed);
304   switch (result) {
305     case WebSocket::FRAME_OK_FINAL:
306     case WebSocket::FRAME_OK_MIDDLE: {
307       if (continuation_message_frames_.empty())
308         is_current_message_compressed_ = compressed;
309       continuation_message_frames_.push_back(current_output);
310 
311       if (result == WebSocket::FRAME_OK_FINAL) {
312         *output = base::StrCat(continuation_message_frames_);
313         continuation_message_frames_.clear();
314         if (is_current_message_compressed_ && !Inflate(output)) {
315           return WebSocket::FRAME_ERROR;
316         }
317       }
318       break;
319     }
320 
321     case WebSocket::FRAME_PING:
322       *output = current_output;
323       break;
324 
325     default:
326       // This function doesn't need special handling for other parse results.
327       break;
328   }
329 
330   return result;
331 }
332 
EncodeTextFrame(std::string_view frame,int masking_key,std::string * output)333 void WebSocketEncoder::EncodeTextFrame(std::string_view frame,
334                                        int masking_key,
335                                        std::string* output) {
336   std::string compressed;
337   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
338   if (Deflate(frame, &compressed))
339     EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
340   else
341     EncodeFrameHybi17(frame, masking_key, false, op_code, output);
342 }
343 
EncodeCloseFrame(std::string_view frame,int masking_key,std::string * output)344 void WebSocketEncoder::EncodeCloseFrame(std::string_view frame,
345                                         int masking_key,
346                                         std::string* output) {
347   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeClose;
348   EncodeFrameHybi17(frame, masking_key, false, op_code, output);
349 }
350 
EncodePongFrame(std::string_view frame,int masking_key,std::string * output)351 void WebSocketEncoder::EncodePongFrame(std::string_view frame,
352                                        int masking_key,
353                                        std::string* output) {
354   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
355   EncodeFrameHybi17(frame, masking_key, false, op_code, output);
356 }
357 
Inflate(std::string * message)358 bool WebSocketEncoder::Inflate(std::string* message) {
359   if (!inflater_)
360     return false;
361   if (!inflater_->AddBytes(message->data(), message->length()))
362     return false;
363   if (!inflater_->Finish())
364     return false;
365 
366   std::vector<char> output;
367   while (inflater_->CurrentOutputSize() > 0) {
368     scoped_refptr<IOBufferWithSize> chunk =
369         inflater_->GetOutput(inflater_->CurrentOutputSize());
370     if (!chunk.get())
371       return false;
372     output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
373   }
374 
375   *message =
376       output.size() ? std::string(output.data(), output.size()) : std::string();
377   return true;
378 }
379 
Deflate(std::string_view message,std::string * output)380 bool WebSocketEncoder::Deflate(std::string_view message, std::string* output) {
381   if (!deflater_)
382     return false;
383   if (!deflater_->AddBytes(message.data(), message.length())) {
384     deflater_->Finish();
385     return false;
386   }
387   if (!deflater_->Finish())
388     return false;
389   scoped_refptr<IOBufferWithSize> buffer =
390       deflater_->GetOutput(deflater_->CurrentOutputSize());
391   if (!buffer.get())
392     return false;
393   *output = std::string(buffer->data(), buffer->size());
394   return true;
395 }
396 
397 }  // namespace net
398