1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ZstdUtil.h"

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <android-base/logging.h>
#include <zstd.h>
21
22 namespace simpleperf {
23
24 namespace {
25
26 class CompressionOutBuffer {
27 public:
CompressionOutBuffer(size_t min_free_size)28 CompressionOutBuffer(size_t min_free_size)
29 : min_free_size_(min_free_size), buffer_(min_free_size) {}
30
DataStart() const31 const char* DataStart() const { return buffer_.data() + data_pos_; }
DataSize() const32 size_t DataSize() const { return data_size_; }
FreeStart()33 char* FreeStart() { return buffer_.data() + data_pos_ + data_size_; }
FreeSize() const34 size_t FreeSize() const { return buffer_.size() - data_pos_ - data_size_; }
35
PrepareForInput()36 void PrepareForInput() {
37 if (data_pos_ > 0) {
38 if (data_size_ == 0) {
39 data_pos_ = 0;
40 } else {
41 memmove(buffer_.data(), buffer_.data() + data_pos_, data_size_);
42 data_pos_ = 0;
43 }
44 }
45 if (FreeSize() < min_free_size_) {
46 buffer_.resize(buffer_.size() * 2);
47 }
48 }
49
ProduceData(size_t size)50 void ProduceData(size_t size) {
51 data_size_ += size;
52 CHECK_LE(data_pos_ + data_size_, buffer_.size());
53 }
54
ConsumeData(size_t size)55 void ConsumeData(size_t size) {
56 CHECK_LE(size, data_size_);
57 data_pos_ += size;
58 data_size_ -= size;
59 }
60
61 private:
62 const size_t min_free_size_;
63 std::vector<char> buffer_;
64 size_t data_pos_ = 0;
65 size_t data_size_ = 0;
66 };
67
68 using ZSTD_CCtx_pointer = std::unique_ptr<ZSTD_CCtx, decltype(&ZSTD_freeCCtx)>;
69
70 class ZstdCompressor : public Compressor {
71 public:
ZstdCompressor(ZSTD_CCtx_pointer cctx)72 ZstdCompressor(ZSTD_CCtx_pointer cctx)
73 : cctx_(std::move(cctx)), out_buffer_(ZSTD_CStreamOutSize()) {}
74
AddInputData(const char * data,size_t size)75 bool AddInputData(const char* data, size_t size) override {
76 ZSTD_inBuffer input = {data, size, 0};
77 while (input.pos < input.size) {
78 out_buffer_.PrepareForInput();
79 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
80 size_t remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_continue);
81 if (ZSTD_isError(remaining)) {
82 LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
83 return false;
84 }
85 out_buffer_.ProduceData(output.pos);
86 total_output_size_ += output.pos;
87 }
88 total_input_size_ += size;
89 return true;
90 }
91
FlushOutputData()92 bool FlushOutputData() override {
93 if (flushed_input_size_ == total_input_size_) {
94 return true;
95 }
96 flushed_input_size_ = total_input_size_;
97 ZSTD_inBuffer input = {nullptr, 0, 0};
98 size_t remaining = 0;
99 do {
100 out_buffer_.PrepareForInput();
101 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
102 remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_end);
103 if (ZSTD_isError(remaining)) {
104 LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
105 return false;
106 }
107 out_buffer_.ProduceData(output.pos);
108 total_output_size_ += output.pos;
109 } while (remaining != 0);
110 return true;
111 }
112
GetOutputData()113 std::string_view GetOutputData() override {
114 return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
115 }
116
ConsumeOutputData(size_t size)117 void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
118
119 private:
120 ZSTD_CCtx_pointer cctx_;
121 CompressionOutBuffer out_buffer_;
122 uint64_t flushed_input_size_ = 0;
123 };
124
125 using ZSTD_DCtx_pointer = std::unique_ptr<ZSTD_DCtx, decltype(&ZSTD_freeDCtx)>;
126
127 class ZstdDecompressor : public Decompressor {
128 public:
ZstdDecompressor(ZSTD_DCtx_pointer dctx)129 ZstdDecompressor(ZSTD_DCtx_pointer dctx)
130 : dctx_(std::move(dctx)), out_buffer_(ZSTD_DStreamOutSize()) {}
131
AddInputData(const char * data,size_t size)132 bool AddInputData(const char* data, size_t size) override {
133 ZSTD_inBuffer input = {data, size, 0};
134 while (input.pos < input.size) {
135 out_buffer_.PrepareForInput();
136 ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
137 size_t remaining = ZSTD_decompressStream(dctx_.get(), &output, &input);
138 if (ZSTD_isError(remaining)) {
139 LOG(ERROR) << "ZSTD_decompressStream() failed: " << ZSTD_getErrorName(remaining);
140 return false;
141 }
142 out_buffer_.ProduceData(output.pos);
143 }
144 return true;
145 }
146
GetOutputData()147 std::string_view GetOutputData() override {
148 return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
149 }
150
ConsumeOutputData(size_t size)151 void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
152
153 private:
154 ZSTD_DCtx_pointer dctx_;
155 CompressionOutBuffer out_buffer_;
156 };
157
158 } // namespace
159
~Compressor()160 Compressor::~Compressor() {}
161
~Decompressor()162 Decompressor::~Decompressor() {}
163
CreateZstdCompressor(size_t compression_level)164 std::unique_ptr<Compressor> CreateZstdCompressor(size_t compression_level) {
165 ZSTD_CCtx_pointer cctx(ZSTD_createCCtx(), ZSTD_freeCCtx);
166 if (!cctx) {
167 LOG(ERROR) << "ZSTD_createCCtx() failed";
168 return nullptr;
169 }
170 size_t err = ZSTD_CCtx_setParameter(cctx.get(), ZSTD_c_compressionLevel, compression_level);
171 if (ZSTD_isError(err)) {
172 LOG(ERROR) << "failed to set compression level: " << ZSTD_getErrorName(err);
173 return nullptr;
174 }
175 return std::unique_ptr<Compressor>(new ZstdCompressor(std::move(cctx)));
176 }
177
CreateZstdDecompressor()178 std::unique_ptr<Decompressor> CreateZstdDecompressor() {
179 ZSTD_DCtx_pointer dctx(ZSTD_createDCtx(), ZSTD_freeDCtx);
180 if (!dctx) {
181 LOG(ERROR) << "ZSTD_createDCtx() failed";
182 return nullptr;
183 }
184 return std::unique_ptr<Decompressor>(new ZstdDecompressor(std::move(dctx)));
185 }
ZstdCompress(const char * input_data,size_t input_size,std::string & output_data)186 bool ZstdCompress(const char* input_data, size_t input_size, std::string& output_data) {
187 std::unique_ptr<Compressor> compressor = CreateZstdCompressor();
188 CHECK(compressor != nullptr);
189 if (!compressor->AddInputData(input_data, input_size)) {
190 return false;
191 }
192 if (!compressor->FlushOutputData()) {
193 return false;
194 }
195 std::string_view output = compressor->GetOutputData();
196 output_data.clear();
197 output_data.insert(0, output.data(), output.size());
198 return true;
199 }
200
ZstdDecompress(const char * input_data,size_t input_size,std::string & output_data)201 bool ZstdDecompress(const char* input_data, size_t input_size, std::string& output_data) {
202 std::unique_ptr<Decompressor> decompressor = CreateZstdDecompressor();
203 CHECK(decompressor != nullptr);
204 if (!decompressor->AddInputData(input_data, input_size)) {
205 return false;
206 }
207 std::string_view output = decompressor->GetOutputData();
208 output_data.clear();
209 output_data.insert(0, output.data(), output.size());
210 return true;
211 }
212
213 } // namespace simpleperf
214