1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/filter/zstd_source_stream.h"
6
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <unordered_map>
#include <utility>
10
11 #define ZSTD_STATIC_LINKING_ONLY
12
13 #include "base/bits.h"
14 #include "base/check_op.h"
15 #include "base/metrics/histogram_macros.h"
16 #include "base/numerics/safe_conversions.h"
17 #include "net/base/io_buffer.h"
18 #include "third_party/zstd/src/lib/zstd.h"
19 #include "third_party/zstd/src/lib/zstd_errors.h"
20
21 namespace net {
22
23 namespace {
24
// Name reported by ZstdSourceStream::GetTypeAsString().
constexpr char kZstd[] = "ZSTD";
26
27 struct FreeContextDeleter {
operator ()net::__anon170898e00111::FreeContextDeleter28 inline void operator()(ZSTD_DCtx* ptr) const { ZSTD_freeDCtx(ptr); }
29 };
30
31 // ZstdSourceStream applies Zstd content decoding to a data stream.
32 // Zstd format speciication: https://datatracker.ietf.org/doc/html/rfc8878
33 class ZstdSourceStream : public FilterSourceStream {
34 public:
ZstdSourceStream(std::unique_ptr<SourceStream> upstream,scoped_refptr<IOBuffer> dictionary=nullptr,size_t dictionary_size=0u)35 explicit ZstdSourceStream(std::unique_ptr<SourceStream> upstream,
36 scoped_refptr<IOBuffer> dictionary = nullptr,
37 size_t dictionary_size = 0u)
38 : FilterSourceStream(SourceStream::TYPE_ZSTD, std::move(upstream)),
39 dictionary_(std::move(dictionary)),
40 dictionary_size_(dictionary_size) {
41 ZSTD_customMem custom_mem = {&customMalloc, &customFree, this};
42 dctx_.reset(ZSTD_createDCtx_advanced(custom_mem));
43 CHECK(dctx_);
44
45 // Following RFC 8878 recommendation (see section 3.1.1.1.2 Window
46 // Descriptor) of using a maximum 8MB memory buffer to decompress frames
47 // to '... protect decoders from unreasonable memory requirements'.
48 int window_log_max = 23;
49 if (dictionary_) {
50 // For shared dictionary case, allow using larger window size (Log2Ceiling
51 // of `dictionary_size`). It is safe because we have the size limit per
52 // shared dictionary and the total dictionary size limit.
53 window_log_max =
54 std::max(base::bits::Log2Ceiling(
55 base::checked_cast<uint32_t>(dictionary_size_)),
56 window_log_max);
57 }
58 ZSTD_DCtx_setParameter(dctx_.get(), ZSTD_d_windowLogMax, window_log_max);
59 if (dictionary_) {
60 size_t result = ZSTD_DCtx_loadDictionary_advanced(
61 dctx_.get(), reinterpret_cast<const void*>(dictionary_->data()),
62 dictionary_size_, ZSTD_dlm_byRef, ZSTD_dct_rawContent);
63 DCHECK(!ZSTD_isError(result));
64 }
65 }
66
67 ZstdSourceStream(const ZstdSourceStream&) = delete;
68 ZstdSourceStream& operator=(const ZstdSourceStream&) = delete;
69
~ZstdSourceStream()70 ~ZstdSourceStream() override {
71 if (ZSTD_isError(decoding_result_)) {
72 ZSTD_ErrorCode error_code = ZSTD_getErrorCode(decoding_result_);
73 UMA_HISTOGRAM_ENUMERATION(
74 "Net.ZstdFilter.ErrorCode", static_cast<int>(error_code),
75 static_cast<int>(ZSTD_ErrorCode::ZSTD_error_maxCode));
76 }
77
78 UMA_HISTOGRAM_ENUMERATION("Net.ZstdFilter.Status", decoding_status_);
79
80 if (decoding_status_ == ZstdDecodingStatus::kEndOfFrame) {
81 // CompressionRatio is undefined when there is no output produced.
82 if (produced_bytes_ != 0) {
83 UMA_HISTOGRAM_PERCENTAGE(
84 "Net.ZstdFilter.CompressionRatio",
85 static_cast<int>((consumed_bytes_ * 100) / produced_bytes_));
86 }
87 }
88
89 UMA_HISTOGRAM_MEMORY_KB("Net.ZstdFilter.MaxMemoryUsage",
90 (max_allocated_ / 1024));
91 }
92
93 private:
customMalloc(void * opaque,size_t size)94 static void* customMalloc(void* opaque, size_t size) {
95 return reinterpret_cast<ZstdSourceStream*>(opaque)->customMalloc(size);
96 }
97
customMalloc(size_t size)98 void* customMalloc(size_t size) {
99 void* address = malloc(size);
100 CHECK(address);
101 malloc_sizes_.emplace(address, size);
102 total_allocated_ += size;
103 if (total_allocated_ > max_allocated_) {
104 max_allocated_ = total_allocated_;
105 }
106 return address;
107 }
108
customFree(void * opaque,void * address)109 static void customFree(void* opaque, void* address) {
110 return reinterpret_cast<ZstdSourceStream*>(opaque)->customFree(address);
111 }
112
customFree(void * address)113 void customFree(void* address) {
114 free(address);
115 auto it = malloc_sizes_.find(address);
116 CHECK(it != malloc_sizes_.end());
117 const size_t size = it->second;
118 total_allocated_ -= size;
119 malloc_sizes_.erase(it);
120 }
121
122 // SourceStream implementation
GetTypeAsString() const123 std::string GetTypeAsString() const override { return kZstd; }
124
FilterData(IOBuffer * output_buffer,size_t output_buffer_size,IOBuffer * input_buffer,size_t input_buffer_size,size_t * consumed_bytes,bool upstream_end_reached)125 base::expected<size_t, Error> FilterData(IOBuffer* output_buffer,
126 size_t output_buffer_size,
127 IOBuffer* input_buffer,
128 size_t input_buffer_size,
129 size_t* consumed_bytes,
130 bool upstream_end_reached) override {
131 CHECK(dctx_);
132 ZSTD_inBuffer input = {input_buffer->data(), input_buffer_size, 0};
133 ZSTD_outBuffer output = {output_buffer->data(), output_buffer_size, 0};
134
135 const size_t result = ZSTD_decompressStream(dctx_.get(), &output, &input);
136
137 decoding_result_ = result;
138
139 produced_bytes_ += output.pos;
140 consumed_bytes_ += input.pos;
141
142 *consumed_bytes = input.pos;
143
144 if (ZSTD_isError(result)) {
145 decoding_status_ = ZstdDecodingStatus::kDecodingError;
146 if (ZSTD_getErrorCode(result) ==
147 ZSTD_error_frameParameter_windowTooLarge) {
148 return base::unexpected(ERR_ZSTD_WINDOW_SIZE_TOO_BIG);
149 }
150 return base::unexpected(ERR_CONTENT_DECODING_FAILED);
151 } else if (input.pos < input.size) {
152 // Given a valid frame, zstd won't consume the last byte of the frame
153 // until it has flushed all of the decompressed data of the frame.
154 // Therefore, instead of checking if the return code is 0, we can
155 // just check if input.pos < input.size.
156 return output.pos;
157 } else {
158 CHECK_EQ(input.pos, input.size);
159 if (result != 0u) {
160 // The return value from ZSTD_decompressStream did not end on a frame,
161 // but we reached the end of the file. We assume this is an error, and
162 // the input was truncated.
163 if (upstream_end_reached) {
164 decoding_status_ = ZstdDecodingStatus::kDecodingError;
165 }
166 } else {
167 CHECK_EQ(result, 0u);
168 CHECK_LE(output.pos, output.size);
169 // Finished decoding a frame.
170 decoding_status_ = ZstdDecodingStatus::kEndOfFrame;
171 }
172 return output.pos;
173 }
174 }
175
176 size_t total_allocated_ = 0;
177 size_t max_allocated_ = 0;
178 std::unordered_map<void*, size_t> malloc_sizes_;
179
180 const scoped_refptr<IOBuffer> dictionary_;
181 const size_t dictionary_size_;
182
183 std::unique_ptr<ZSTD_DCtx, FreeContextDeleter> dctx_;
184
185 ZstdDecodingStatus decoding_status_ = ZstdDecodingStatus::kDecodingInProgress;
186
187 size_t decoding_result_ = 0;
188 size_t consumed_bytes_ = 0;
189 size_t produced_bytes_ = 0;
190 };
191
192 } // namespace
193
CreateZstdSourceStream(std::unique_ptr<SourceStream> previous)194 std::unique_ptr<FilterSourceStream> CreateZstdSourceStream(
195 std::unique_ptr<SourceStream> previous) {
196 return std::make_unique<ZstdSourceStream>(std::move(previous));
197 }
198
CreateZstdSourceStreamWithDictionary(std::unique_ptr<SourceStream> previous,scoped_refptr<IOBuffer> dictionary,size_t dictionary_size)199 std::unique_ptr<FilterSourceStream> CreateZstdSourceStreamWithDictionary(
200 std::unique_ptr<SourceStream> previous,
201 scoped_refptr<IOBuffer> dictionary,
202 size_t dictionary_size) {
203 return std::make_unique<ZstdSourceStream>(
204 std::move(previous), std::move(dictionary), dictionary_size);
205 }
206
207 } // namespace net
208