xref: /aosp_15_r20/external/cronet/net/filter/zstd_source_stream.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/filter/zstd_source_stream.h"
6 
7 #include <algorithm>
8 #include <unordered_map>
9 #include <utility>
10 
11 #define ZSTD_STATIC_LINKING_ONLY
12 
13 #include "base/bits.h"
14 #include "base/check_op.h"
15 #include "base/metrics/histogram_macros.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "net/base/io_buffer.h"
18 #include "third_party/zstd/src/lib/zstd.h"
19 #include "third_party/zstd/src/lib/zstd_errors.h"
20 
21 namespace net {
22 
23 namespace {
24 
// Human-readable name of this filter type, as reported by GetTypeAsString().
constexpr char kZstd[] = "ZSTD";
26 
27 struct FreeContextDeleter {
operator ()net::__anon170898e00111::FreeContextDeleter28   inline void operator()(ZSTD_DCtx* ptr) const { ZSTD_freeDCtx(ptr); }
29 };
30 
31 // ZstdSourceStream applies Zstd content decoding to a data stream.
32 // Zstd format speciication: https://datatracker.ietf.org/doc/html/rfc8878
33 class ZstdSourceStream : public FilterSourceStream {
34  public:
ZstdSourceStream(std::unique_ptr<SourceStream> upstream,scoped_refptr<IOBuffer> dictionary=nullptr,size_t dictionary_size=0u)35   explicit ZstdSourceStream(std::unique_ptr<SourceStream> upstream,
36                             scoped_refptr<IOBuffer> dictionary = nullptr,
37                             size_t dictionary_size = 0u)
38       : FilterSourceStream(SourceStream::TYPE_ZSTD, std::move(upstream)),
39         dictionary_(std::move(dictionary)),
40         dictionary_size_(dictionary_size) {
41     ZSTD_customMem custom_mem = {&customMalloc, &customFree, this};
42     dctx_.reset(ZSTD_createDCtx_advanced(custom_mem));
43     CHECK(dctx_);
44 
45     // Following RFC 8878 recommendation (see section 3.1.1.1.2 Window
46     // Descriptor) of using a maximum 8MB memory buffer to decompress frames
47     // to '... protect decoders from unreasonable memory requirements'.
48     int window_log_max = 23;
49     if (dictionary_) {
50       // For shared dictionary case, allow using larger window size (Log2Ceiling
51       // of `dictionary_size`). It is safe because we have the size limit per
52       // shared dictionary and the total dictionary size limit.
53       window_log_max =
54           std::max(base::bits::Log2Ceiling(
55                        base::checked_cast<uint32_t>(dictionary_size_)),
56                    window_log_max);
57     }
58     ZSTD_DCtx_setParameter(dctx_.get(), ZSTD_d_windowLogMax, window_log_max);
59     if (dictionary_) {
60       size_t result = ZSTD_DCtx_loadDictionary_advanced(
61           dctx_.get(), reinterpret_cast<const void*>(dictionary_->data()),
62           dictionary_size_, ZSTD_dlm_byRef, ZSTD_dct_rawContent);
63       DCHECK(!ZSTD_isError(result));
64     }
65   }
66 
67   ZstdSourceStream(const ZstdSourceStream&) = delete;
68   ZstdSourceStream& operator=(const ZstdSourceStream&) = delete;
69 
~ZstdSourceStream()70   ~ZstdSourceStream() override {
71     if (ZSTD_isError(decoding_result_)) {
72       ZSTD_ErrorCode error_code = ZSTD_getErrorCode(decoding_result_);
73       UMA_HISTOGRAM_ENUMERATION(
74           "Net.ZstdFilter.ErrorCode", static_cast<int>(error_code),
75           static_cast<int>(ZSTD_ErrorCode::ZSTD_error_maxCode));
76     }
77 
78     UMA_HISTOGRAM_ENUMERATION("Net.ZstdFilter.Status", decoding_status_);
79 
80     if (decoding_status_ == ZstdDecodingStatus::kEndOfFrame) {
81       // CompressionRatio is undefined when there is no output produced.
82       if (produced_bytes_ != 0) {
83         UMA_HISTOGRAM_PERCENTAGE(
84             "Net.ZstdFilter.CompressionRatio",
85             static_cast<int>((consumed_bytes_ * 100) / produced_bytes_));
86       }
87     }
88 
89     UMA_HISTOGRAM_MEMORY_KB("Net.ZstdFilter.MaxMemoryUsage",
90                             (max_allocated_ / 1024));
91   }
92 
93  private:
customMalloc(void * opaque,size_t size)94   static void* customMalloc(void* opaque, size_t size) {
95     return reinterpret_cast<ZstdSourceStream*>(opaque)->customMalloc(size);
96   }
97 
customMalloc(size_t size)98   void* customMalloc(size_t size) {
99     void* address = malloc(size);
100     CHECK(address);
101     malloc_sizes_.emplace(address, size);
102     total_allocated_ += size;
103     if (total_allocated_ > max_allocated_) {
104       max_allocated_ = total_allocated_;
105     }
106     return address;
107   }
108 
customFree(void * opaque,void * address)109   static void customFree(void* opaque, void* address) {
110     return reinterpret_cast<ZstdSourceStream*>(opaque)->customFree(address);
111   }
112 
customFree(void * address)113   void customFree(void* address) {
114     free(address);
115     auto it = malloc_sizes_.find(address);
116     CHECK(it != malloc_sizes_.end());
117     const size_t size = it->second;
118     total_allocated_ -= size;
119     malloc_sizes_.erase(it);
120   }
121 
122   // SourceStream implementation
GetTypeAsString() const123   std::string GetTypeAsString() const override { return kZstd; }
124 
FilterData(IOBuffer * output_buffer,size_t output_buffer_size,IOBuffer * input_buffer,size_t input_buffer_size,size_t * consumed_bytes,bool upstream_end_reached)125   base::expected<size_t, Error> FilterData(IOBuffer* output_buffer,
126                                            size_t output_buffer_size,
127                                            IOBuffer* input_buffer,
128                                            size_t input_buffer_size,
129                                            size_t* consumed_bytes,
130                                            bool upstream_end_reached) override {
131     CHECK(dctx_);
132     ZSTD_inBuffer input = {input_buffer->data(), input_buffer_size, 0};
133     ZSTD_outBuffer output = {output_buffer->data(), output_buffer_size, 0};
134 
135     const size_t result = ZSTD_decompressStream(dctx_.get(), &output, &input);
136 
137     decoding_result_ = result;
138 
139     produced_bytes_ += output.pos;
140     consumed_bytes_ += input.pos;
141 
142     *consumed_bytes = input.pos;
143 
144     if (ZSTD_isError(result)) {
145       decoding_status_ = ZstdDecodingStatus::kDecodingError;
146       if (ZSTD_getErrorCode(result) ==
147           ZSTD_error_frameParameter_windowTooLarge) {
148         return base::unexpected(ERR_ZSTD_WINDOW_SIZE_TOO_BIG);
149       }
150       return base::unexpected(ERR_CONTENT_DECODING_FAILED);
151     } else if (input.pos < input.size) {
152       // Given a valid frame, zstd won't consume the last byte of the frame
153       // until it has flushed all of the decompressed data of the frame.
154       // Therefore, instead of checking if the return code is 0, we can
155       // just check if input.pos < input.size.
156       return output.pos;
157     } else {
158       CHECK_EQ(input.pos, input.size);
159       if (result != 0u) {
160         // The return value from ZSTD_decompressStream did not end on a frame,
161         // but we reached the end of the file. We assume this is an error, and
162         // the input was truncated.
163         if (upstream_end_reached) {
164           decoding_status_ = ZstdDecodingStatus::kDecodingError;
165         }
166       } else {
167         CHECK_EQ(result, 0u);
168         CHECK_LE(output.pos, output.size);
169         // Finished decoding a frame.
170         decoding_status_ = ZstdDecodingStatus::kEndOfFrame;
171       }
172       return output.pos;
173     }
174   }
175 
176   size_t total_allocated_ = 0;
177   size_t max_allocated_ = 0;
178   std::unordered_map<void*, size_t> malloc_sizes_;
179 
180   const scoped_refptr<IOBuffer> dictionary_;
181   const size_t dictionary_size_;
182 
183   std::unique_ptr<ZSTD_DCtx, FreeContextDeleter> dctx_;
184 
185   ZstdDecodingStatus decoding_status_ = ZstdDecodingStatus::kDecodingInProgress;
186 
187   size_t decoding_result_ = 0;
188   size_t consumed_bytes_ = 0;
189   size_t produced_bytes_ = 0;
190 };
191 
192 }  // namespace
193 
CreateZstdSourceStream(std::unique_ptr<SourceStream> previous)194 std::unique_ptr<FilterSourceStream> CreateZstdSourceStream(
195     std::unique_ptr<SourceStream> previous) {
196   return std::make_unique<ZstdSourceStream>(std::move(previous));
197 }
198 
CreateZstdSourceStreamWithDictionary(std::unique_ptr<SourceStream> previous,scoped_refptr<IOBuffer> dictionary,size_t dictionary_size)199 std::unique_ptr<FilterSourceStream> CreateZstdSourceStreamWithDictionary(
200     std::unique_ptr<SourceStream> previous,
201     scoped_refptr<IOBuffer> dictionary,
202     size_t dictionary_size) {
203   return std::make_unique<ZstdSourceStream>(
204       std::move(previous), std::move(dictionary), dictionary_size);
205 }
206 
207 }  // namespace net
208