xref: /aosp_15_r20/external/cronet/net/websockets/websocket_deflate_stream.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_deflate_stream.h"
6 
7 #include <stdint.h>
8 
9 #include <algorithm>
10 #include <ostream>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/functional/bind.h"
18 #include "base/functional/callback.h"
19 #include "base/logging.h"
20 #include "base/memory/scoped_refptr.h"
21 #include "base/notreached.h"
22 #include "net/base/io_buffer.h"
23 #include "net/base/net_errors.h"
24 #include "net/websockets/websocket_deflate_parameters.h"
25 #include "net/websockets/websocket_deflate_predictor.h"
26 #include "net/websockets/websocket_deflater.h"
27 #include "net/websockets/websocket_frame.h"
28 #include "net/websockets/websocket_inflater.h"
29 #include "net/websockets/websocket_stream.h"
30 
31 namespace net {
32 class NetLogWithSource;
33 
namespace {

// Window size, in bits, that |inflater_| is initialized with.
constexpr int kWindowBits = 15;

// Chunking granularity in bytes (4 KiB). It is passed as both constructor
// arguments of the inflater. It is also the amount of buffered deflater output
// that triggers emitting a compressed frame, and the maximum payload size of
// each inflated frame handed to the reader.
constexpr size_t kChunkSize = 4 * 1024;

}  // namespace
40 
WebSocketDeflateStream(std::unique_ptr<WebSocketStream> stream,const WebSocketDeflateParameters & params,std::unique_ptr<WebSocketDeflatePredictor> predictor)41 WebSocketDeflateStream::WebSocketDeflateStream(
42     std::unique_ptr<WebSocketStream> stream,
43     const WebSocketDeflateParameters& params,
44     std::unique_ptr<WebSocketDeflatePredictor> predictor)
45     : stream_(std::move(stream)),
46       deflater_(params.client_context_take_over_mode()),
47       inflater_(kChunkSize, kChunkSize),
48       predictor_(std::move(predictor)) {
49   DCHECK(stream_);
50   DCHECK(params.IsValidAsResponse());
51   int client_max_window_bits = 15;
52   if (params.is_client_max_window_bits_specified()) {
53     DCHECK(params.has_client_max_window_bits_value());
54     client_max_window_bits = params.client_max_window_bits();
55   }
56   deflater_.Initialize(client_max_window_bits);
57   inflater_.Initialize(kWindowBits);
58 }
59 
60 WebSocketDeflateStream::~WebSocketDeflateStream() = default;
61 
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)62 int WebSocketDeflateStream::ReadFrames(
63     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
64     CompletionOnceCallback callback) {
65   read_callback_ = std::move(callback);
66   inflater_outputs_.clear();
67   int result = stream_->ReadFrames(
68       frames, base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
69                              base::Unretained(this), base::Unretained(frames)));
70   if (result < 0)
71     return result;
72   DCHECK_EQ(OK, result);
73   DCHECK(!frames->empty());
74 
75   return InflateAndReadIfNecessary(frames);
76 }
77 
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)78 int WebSocketDeflateStream::WriteFrames(
79     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
80     CompletionOnceCallback callback) {
81   deflater_outputs_.clear();
82   int result = Deflate(frames);
83   if (result != OK)
84     return result;
85   if (frames->empty())
86     return OK;
87   return stream_->WriteFrames(frames, std::move(callback));
88 }
89 
Close()90 void WebSocketDeflateStream::Close() { stream_->Close(); }
91 
GetSubProtocol() const92 std::string WebSocketDeflateStream::GetSubProtocol() const {
93   return stream_->GetSubProtocol();
94 }
95 
GetExtensions() const96 std::string WebSocketDeflateStream::GetExtensions() const {
97   return stream_->GetExtensions();
98 }
99 
GetNetLogWithSource() const100 const NetLogWithSource& WebSocketDeflateStream::GetNetLogWithSource() const {
101   return stream_->GetNetLogWithSource();
102 }
103 
OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>> * frames,int result)104 void WebSocketDeflateStream::OnReadComplete(
105     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
106     int result) {
107   if (result != OK) {
108     frames->clear();
109     std::move(read_callback_).Run(result);
110     return;
111   }
112 
113   int r = InflateAndReadIfNecessary(frames);
114   if (r != ERR_IO_PENDING)
115     std::move(read_callback_).Run(r);
116 }
117 
Deflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)118 int WebSocketDeflateStream::Deflate(
119     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
120   std::vector<std::unique_ptr<WebSocketFrame>> frames_to_write;
121   // Store frames of the currently processed message if writing_state_ equals to
122   // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
123   std::vector<std::unique_ptr<WebSocketFrame>> frames_of_message;
124   for (size_t i = 0; i < frames->size(); ++i) {
125     DCHECK(!(*frames)[i]->header.reserved1);
126     if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
127       frames_to_write.push_back(std::move((*frames)[i]));
128       continue;
129     }
130     if (writing_state_ == NOT_WRITING)
131       OnMessageStart(*frames, i);
132 
133     std::unique_ptr<WebSocketFrame> frame(std::move((*frames)[i]));
134     predictor_->RecordInputDataFrame(frame.get());
135 
136     if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
137       if (frame->header.final)
138         writing_state_ = NOT_WRITING;
139       predictor_->RecordWrittenDataFrame(frame.get());
140       frames_to_write.push_back(std::move(frame));
141       current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
142     } else {
143       if (frame->payload &&
144           !deflater_.AddBytes(
145               frame->payload,
146               static_cast<size_t>(frame->header.payload_length))) {
147         DVLOG(1) << "WebSocket protocol error. "
148                  << "deflater_.AddBytes() returns an error.";
149         return ERR_WS_PROTOCOL_ERROR;
150       }
151       if (frame->header.final && !deflater_.Finish()) {
152         DVLOG(1) << "WebSocket protocol error. "
153                  << "deflater_.Finish() returns an error.";
154         return ERR_WS_PROTOCOL_ERROR;
155       }
156 
157       if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
158         if (deflater_.CurrentOutputSize() >= kChunkSize ||
159             frame->header.final) {
160           int result = AppendCompressedFrame(frame->header, &frames_to_write);
161           if (result != OK)
162             return result;
163         }
164         if (frame->header.final)
165           writing_state_ = NOT_WRITING;
166       } else {
167         DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
168         bool final = frame->header.final;
169         frames_of_message.push_back(std::move(frame));
170         if (final) {
171           int result = AppendPossiblyCompressedMessage(&frames_of_message,
172                                                        &frames_to_write);
173           if (result != OK)
174             return result;
175           frames_of_message.clear();
176           writing_state_ = NOT_WRITING;
177         }
178       }
179     }
180   }
181   DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
182   frames->swap(frames_to_write);
183   return OK;
184 }
185 
OnMessageStart(const std::vector<std::unique_ptr<WebSocketFrame>> & frames,size_t index)186 void WebSocketDeflateStream::OnMessageStart(
187     const std::vector<std::unique_ptr<WebSocketFrame>>& frames,
188     size_t index) {
189   WebSocketFrame* frame = frames[index].get();
190   current_writing_opcode_ = frame->header.opcode;
191   DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
192          current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
193   WebSocketDeflatePredictor::Result prediction =
194       predictor_->Predict(frames, index);
195 
196   switch (prediction) {
197     case WebSocketDeflatePredictor::DEFLATE:
198       writing_state_ = WRITING_COMPRESSED_MESSAGE;
199       return;
200     case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
201       writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
202       return;
203     case WebSocketDeflatePredictor::TRY_DEFLATE:
204       writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
205       return;
206   }
207   NOTREACHED();
208 }
209 
AppendCompressedFrame(const WebSocketFrameHeader & header,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)210 int WebSocketDeflateStream::AppendCompressedFrame(
211     const WebSocketFrameHeader& header,
212     std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
213   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
214   scoped_refptr<IOBufferWithSize> compressed_payload =
215       deflater_.GetOutput(deflater_.CurrentOutputSize());
216   if (!compressed_payload.get()) {
217     DVLOG(1) << "WebSocket protocol error. "
218              << "deflater_.GetOutput() returns an error.";
219     return ERR_WS_PROTOCOL_ERROR;
220   }
221   deflater_outputs_.push_back(compressed_payload);
222   auto compressed = std::make_unique<WebSocketFrame>(opcode);
223   compressed->header.CopyFrom(header);
224   compressed->header.opcode = opcode;
225   compressed->header.final = header.final;
226   compressed->header.reserved1 =
227       (opcode != WebSocketFrameHeader::kOpCodeContinuation);
228   compressed->payload = compressed_payload->data();
229   compressed->header.payload_length = compressed_payload->size();
230 
231   current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
232   predictor_->RecordWrittenDataFrame(compressed.get());
233   frames_to_write->push_back(std::move(compressed));
234   return OK;
235 }
236 
AppendPossiblyCompressedMessage(std::vector<std::unique_ptr<WebSocketFrame>> * frames,std::vector<std::unique_ptr<WebSocketFrame>> * frames_to_write)237 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
238     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
239     std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
240   DCHECK(!frames->empty());
241 
242   const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
243   scoped_refptr<IOBufferWithSize> compressed_payload =
244       deflater_.GetOutput(deflater_.CurrentOutputSize());
245   if (!compressed_payload.get()) {
246     DVLOG(1) << "WebSocket protocol error. "
247              << "deflater_.GetOutput() returns an error.";
248     return ERR_WS_PROTOCOL_ERROR;
249   }
250   deflater_outputs_.push_back(compressed_payload);
251 
252   uint64_t original_payload_length = 0;
253   for (size_t i = 0; i < frames->size(); ++i) {
254     WebSocketFrame* frame = (*frames)[i].get();
255     // Asserts checking that frames represent one whole data message.
256     DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
257     DCHECK_EQ(i == 0,
258               WebSocketFrameHeader::kOpCodeContinuation !=
259               frame->header.opcode);
260     DCHECK_EQ(i == frames->size() - 1, frame->header.final);
261     original_payload_length += frame->header.payload_length;
262   }
263   if (original_payload_length <=
264       static_cast<uint64_t>(compressed_payload->size())) {
265     // Compression is not effective. Use the original frames.
266     for (auto& frame : *frames) {
267       predictor_->RecordWrittenDataFrame(frame.get());
268       frames_to_write->push_back(std::move(frame));
269     }
270     frames->clear();
271     return OK;
272   }
273   auto compressed = std::make_unique<WebSocketFrame>(opcode);
274   compressed->header.CopyFrom((*frames)[0]->header);
275   compressed->header.opcode = opcode;
276   compressed->header.final = true;
277   compressed->header.reserved1 = true;
278   compressed->payload = compressed_payload->data();
279   compressed->header.payload_length = compressed_payload->size();
280 
281   predictor_->RecordWrittenDataFrame(compressed.get());
282   frames_to_write->push_back(std::move(compressed));
283   return OK;
284 }
285 
Inflate(std::vector<std::unique_ptr<WebSocketFrame>> * frames)286 int WebSocketDeflateStream::Inflate(
287     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
288   std::vector<std::unique_ptr<WebSocketFrame>> frames_to_output;
289   std::vector<std::unique_ptr<WebSocketFrame>> frames_passed;
290   frames->swap(frames_passed);
291   for (auto& frame_passed : frames_passed) {
292     std::unique_ptr<WebSocketFrame> frame(std::move(frame_passed));
293     frame_passed = nullptr;
294     DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
295              << " final=" << frame->header.final
296              << " reserved1=" << frame->header.reserved1
297              << " payload_length=" << frame->header.payload_length;
298 
299     if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
300       frames_to_output.push_back(std::move(frame));
301       continue;
302     }
303 
304     if (reading_state_ == NOT_READING) {
305       if (frame->header.reserved1)
306         reading_state_ = READING_COMPRESSED_MESSAGE;
307       else
308         reading_state_ = READING_UNCOMPRESSED_MESSAGE;
309       current_reading_opcode_ = frame->header.opcode;
310     } else {
311       if (frame->header.reserved1) {
312         DVLOG(1) << "WebSocket protocol error. "
313                  << "Receiving a non-first frame with RSV1 flag set.";
314         return ERR_WS_PROTOCOL_ERROR;
315       }
316     }
317 
318     if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
319       if (frame->header.final)
320         reading_state_ = NOT_READING;
321       current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
322       frames_to_output.push_back(std::move(frame));
323     } else {
324       DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
325       if (frame->payload &&
326           !inflater_.AddBytes(
327               frame->payload,
328               static_cast<size_t>(frame->header.payload_length))) {
329         DVLOG(1) << "WebSocket protocol error. "
330                  << "inflater_.AddBytes() returns an error.";
331         return ERR_WS_PROTOCOL_ERROR;
332       }
333       if (frame->header.final) {
334         if (!inflater_.Finish()) {
335           DVLOG(1) << "WebSocket protocol error. "
336                    << "inflater_.Finish() returns an error.";
337           return ERR_WS_PROTOCOL_ERROR;
338         }
339       }
340       // TODO(yhirano): Many frames can be generated by the inflater and
341       // memory consumption can grow.
342       // We could avoid it, but avoiding it makes this class much more
343       // complicated.
344       while (inflater_.CurrentOutputSize() >= kChunkSize ||
345              frame->header.final) {
346         size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
347         auto inflated =
348             std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
349         scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
350         inflater_outputs_.push_back(data);
351         bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
352         if (!data.get()) {
353           DVLOG(1) << "WebSocket protocol error. "
354                    << "inflater_.GetOutput() returns an error.";
355           return ERR_WS_PROTOCOL_ERROR;
356         }
357         inflated->header.CopyFrom(frame->header);
358         inflated->header.opcode = current_reading_opcode_;
359         inflated->header.final = is_final;
360         inflated->header.reserved1 = false;
361         inflated->payload = data->data();
362         inflated->header.payload_length = data->size();
363         DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
364                  << " final=" << inflated->header.final
365                  << " reserved1=" << inflated->header.reserved1
366                  << " payload_length=" << inflated->header.payload_length;
367         frames_to_output.push_back(std::move(inflated));
368         current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
369         if (is_final)
370           break;
371       }
372       if (frame->header.final)
373         reading_state_ = NOT_READING;
374     }
375   }
376   frames->swap(frames_to_output);
377   return frames->empty() ? ERR_IO_PENDING : OK;
378 }
379 
InflateAndReadIfNecessary(std::vector<std::unique_ptr<WebSocketFrame>> * frames)380 int WebSocketDeflateStream::InflateAndReadIfNecessary(
381     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
382   int result = Inflate(frames);
383   while (result == ERR_IO_PENDING) {
384     DCHECK(frames->empty());
385 
386     result = stream_->ReadFrames(
387         frames,
388         base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
389                        base::Unretained(this), base::Unretained(frames)));
390     if (result < 0)
391       break;
392     DCHECK_EQ(OK, result);
393     DCHECK(!frames->empty());
394 
395     result = Inflate(frames);
396   }
397   if (result < 0)
398     frames->clear();
399   return result;
400 }
401 
402 }  // namespace net
403