xref: /aosp_15_r20/external/cronet/net/filter/filter_source_stream.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2016 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/filter/filter_source_stream.h"
6 
7 #include <string_view>
8 #include <utility>
9 
10 #include "base/check_op.h"
11 #include "base/containers/fixed_flat_map.h"
12 #include "base/functional/bind.h"
13 #include "base/metrics/histogram_macros.h"
14 #include "base/notreached.h"
15 #include "base/numerics/safe_conversions.h"
16 #include "base/strings/string_util.h"
17 #include "components/miracle_parameter/common/public/miracle_parameter.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/net_errors.h"
20 
21 namespace net {
22 
23 namespace {
24 
25 constexpr char kDeflate[] = "deflate";
26 constexpr char kGZip[] = "gzip";
27 constexpr char kXGZip[] = "x-gzip";
28 constexpr char kBrotli[] = "br";
29 constexpr char kZstd[] = "zstd";
30 
31 BASE_FEATURE(kBufferSizeForFilterSourceStreamFeature,
32              "BufferSizeForFilterSourceStreamFeature",
33              base::FEATURE_ENABLED_BY_DEFAULT);
34 
35 MIRACLE_PARAMETER_FOR_INT(GetBufferSizeForFilterSourceStream,
36                           kBufferSizeForFilterSourceStreamFeature,
37                           "BufferSizeForFilterSourceStream",
38                           32 * 1024)
39 
40 }  // namespace
41 
FilterSourceStream(SourceType type,std::unique_ptr<SourceStream> upstream)42 FilterSourceStream::FilterSourceStream(SourceType type,
43                                        std::unique_ptr<SourceStream> upstream)
44     : SourceStream(type), upstream_(std::move(upstream)) {
45   DCHECK(upstream_);
46 }
47 
48 FilterSourceStream::~FilterSourceStream() = default;
49 
Read(IOBuffer * read_buffer,int read_buffer_size,CompletionOnceCallback callback)50 int FilterSourceStream::Read(IOBuffer* read_buffer,
51                              int read_buffer_size,
52                              CompletionOnceCallback callback) {
53   DCHECK_EQ(STATE_NONE, next_state_);
54   DCHECK(read_buffer);
55   DCHECK_LT(0, read_buffer_size);
56 
57   // Allocate a BlockBuffer during first Read().
58   if (!input_buffer_) {
59     input_buffer_ = base::MakeRefCounted<IOBufferWithSize>(
60         GetBufferSizeForFilterSourceStream());
61     // This is first Read(), start with reading data from |upstream_|.
62     next_state_ = STATE_READ_DATA;
63   } else {
64     // Otherwise start with filtering data, which will tell us whether this
65     // stream needs input data.
66     next_state_ = STATE_FILTER_DATA;
67   }
68 
69   output_buffer_ = read_buffer;
70   output_buffer_size_ = base::checked_cast<size_t>(read_buffer_size);
71   int rv = DoLoop(OK);
72 
73   if (rv == ERR_IO_PENDING)
74     callback_ = std::move(callback);
75   return rv;
76 }
77 
Description() const78 std::string FilterSourceStream::Description() const {
79   std::string next_type_string = upstream_->Description();
80   if (next_type_string.empty())
81     return GetTypeAsString();
82   return next_type_string + "," + GetTypeAsString();
83 }
84 
MayHaveMoreBytes() const85 bool FilterSourceStream::MayHaveMoreBytes() const {
86   return !upstream_end_reached_;
87 }
88 
ParseEncodingType(const std::string & encoding)89 FilterSourceStream::SourceType FilterSourceStream::ParseEncodingType(
90     const std::string& encoding) {
91   std::string lower_encoding = base::ToLowerASCII(encoding);
92   static constexpr auto kEncodingMap =
93       base::MakeFixedFlatMap<std::string_view, SourceType>({
94           {"", TYPE_NONE},
95           {kBrotli, TYPE_BROTLI},
96           {kDeflate, TYPE_DEFLATE},
97           {kGZip, TYPE_GZIP},
98           {kXGZip, TYPE_GZIP},
99           {kZstd, TYPE_ZSTD},
100       });
101   auto encoding_type = kEncodingMap.find(lower_encoding);
102   if (encoding_type == kEncodingMap.end()) {
103     return TYPE_UNKNOWN;
104   }
105   return encoding_type->second;
106 }
107 
DoLoop(int result)108 int FilterSourceStream::DoLoop(int result) {
109   DCHECK_NE(STATE_NONE, next_state_);
110 
111   int rv = result;
112   do {
113     State state = next_state_;
114     next_state_ = STATE_NONE;
115     switch (state) {
116       case STATE_READ_DATA:
117         rv = DoReadData();
118         break;
119       case STATE_READ_DATA_COMPLETE:
120         rv = DoReadDataComplete(rv);
121         break;
122       case STATE_FILTER_DATA:
123         DCHECK_LE(0, rv);
124         rv = DoFilterData();
125         break;
126       default:
127         NOTREACHED() << "bad state: " << state;
128         rv = ERR_UNEXPECTED;
129         break;
130     }
131   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
132   return rv;
133 }
134 
DoReadData()135 int FilterSourceStream::DoReadData() {
136   // Read more data means subclasses have consumed all input or this is the
137   // first read in which case the |drainable_input_buffer_| is not initialized.
138   DCHECK(drainable_input_buffer_ == nullptr ||
139          0 == drainable_input_buffer_->BytesRemaining());
140 
141   next_state_ = STATE_READ_DATA_COMPLETE;
142   // Use base::Unretained here is safe because |this| owns |upstream_|.
143   int rv =
144       upstream_->Read(input_buffer_.get(), GetBufferSizeForFilterSourceStream(),
145                       base::BindOnce(&FilterSourceStream::OnIOComplete,
146                                      base::Unretained(this)));
147 
148   return rv;
149 }
150 
DoReadDataComplete(int result)151 int FilterSourceStream::DoReadDataComplete(int result) {
152   DCHECK_NE(ERR_IO_PENDING, result);
153 
154   if (result >= OK) {
155     drainable_input_buffer_ =
156         base::MakeRefCounted<DrainableIOBuffer>(input_buffer_, result);
157     next_state_ = STATE_FILTER_DATA;
158   }
159   if (result <= OK)
160     upstream_end_reached_ = true;
161   return result;
162 }
163 
DoFilterData()164 int FilterSourceStream::DoFilterData() {
165   DCHECK(output_buffer_);
166   DCHECK(drainable_input_buffer_);
167 
168   size_t consumed_bytes = 0;
169   base::expected<size_t, Error> bytes_output = FilterData(
170       output_buffer_.get(), output_buffer_size_, drainable_input_buffer_.get(),
171       drainable_input_buffer_->BytesRemaining(), &consumed_bytes,
172       upstream_end_reached_);
173 
174   const auto bytes_remaining =
175       base::checked_cast<size_t>(drainable_input_buffer_->BytesRemaining());
176   if (bytes_output.has_value() && bytes_output.value() == 0) {
177     DCHECK_EQ(consumed_bytes, bytes_remaining);
178   } else {
179     DCHECK_LE(consumed_bytes, bytes_remaining);
180   }
181   // FilterData() is not allowed to return ERR_IO_PENDING.
182   if (!bytes_output.has_value())
183     DCHECK_NE(ERR_IO_PENDING, bytes_output.error());
184 
185   if (consumed_bytes > 0)
186     drainable_input_buffer_->DidConsume(consumed_bytes);
187 
188   // Received data or encountered an error.
189   if (!bytes_output.has_value()) {
190     CHECK_LT(bytes_output.error(), 0);
191     return bytes_output.error();
192   }
193   if (bytes_output.value() != 0)
194     return base::checked_cast<int>(bytes_output.value());
195 
196   // If no data is returned, continue reading if |this| needs more input.
197   if (NeedMoreData()) {
198     DCHECK_EQ(0, drainable_input_buffer_->BytesRemaining());
199     next_state_ = STATE_READ_DATA;
200   }
201   return 0;
202 }
203 
OnIOComplete(int result)204 void FilterSourceStream::OnIOComplete(int result) {
205   DCHECK_EQ(STATE_READ_DATA_COMPLETE, next_state_);
206 
207   int rv = DoLoop(result);
208   if (rv == ERR_IO_PENDING)
209     return;
210 
211   output_buffer_ = nullptr;
212   output_buffer_size_ = 0;
213 
214   std::move(callback_).Run(rv);
215 }
216 
NeedMoreData() const217 bool FilterSourceStream::NeedMoreData() const {
218   return !upstream_end_reached_;
219 }
220 
221 }  // namespace net
222