xref: /aosp_15_r20/system/extras/simpleperf/ZstdUtil.cpp (revision 288bf5226967eb3dac5cce6c939ccc2a7f2b4fe5)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ZstdUtil.h"
18 
19 #include <android-base/logging.h>
20 #include <zstd.h>
21 
22 namespace simpleperf {
23 
24 namespace {
25 
26 class CompressionOutBuffer {
27  public:
CompressionOutBuffer(size_t min_free_size)28   CompressionOutBuffer(size_t min_free_size)
29       : min_free_size_(min_free_size), buffer_(min_free_size) {}
30 
DataStart() const31   const char* DataStart() const { return buffer_.data() + data_pos_; }
DataSize() const32   size_t DataSize() const { return data_size_; }
FreeStart()33   char* FreeStart() { return buffer_.data() + data_pos_ + data_size_; }
FreeSize() const34   size_t FreeSize() const { return buffer_.size() - data_pos_ - data_size_; }
35 
PrepareForInput()36   void PrepareForInput() {
37     if (data_pos_ > 0) {
38       if (data_size_ == 0) {
39         data_pos_ = 0;
40       } else {
41         memmove(buffer_.data(), buffer_.data() + data_pos_, data_size_);
42         data_pos_ = 0;
43       }
44     }
45     if (FreeSize() < min_free_size_) {
46       buffer_.resize(buffer_.size() * 2);
47     }
48   }
49 
ProduceData(size_t size)50   void ProduceData(size_t size) {
51     data_size_ += size;
52     CHECK_LE(data_pos_ + data_size_, buffer_.size());
53   }
54 
ConsumeData(size_t size)55   void ConsumeData(size_t size) {
56     CHECK_LE(size, data_size_);
57     data_pos_ += size;
58     data_size_ -= size;
59   }
60 
61  private:
62   const size_t min_free_size_;
63   std::vector<char> buffer_;
64   size_t data_pos_ = 0;
65   size_t data_size_ = 0;
66 };
67 
68 using ZSTD_CCtx_pointer = std::unique_ptr<ZSTD_CCtx, decltype(&ZSTD_freeCCtx)>;
69 
70 class ZstdCompressor : public Compressor {
71  public:
ZstdCompressor(ZSTD_CCtx_pointer cctx)72   ZstdCompressor(ZSTD_CCtx_pointer cctx)
73       : cctx_(std::move(cctx)), out_buffer_(ZSTD_CStreamOutSize()) {}
74 
AddInputData(const char * data,size_t size)75   bool AddInputData(const char* data, size_t size) override {
76     ZSTD_inBuffer input = {data, size, 0};
77     while (input.pos < input.size) {
78       out_buffer_.PrepareForInput();
79       ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
80       size_t remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_continue);
81       if (ZSTD_isError(remaining)) {
82         LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
83         return false;
84       }
85       out_buffer_.ProduceData(output.pos);
86       total_output_size_ += output.pos;
87     }
88     total_input_size_ += size;
89     return true;
90   }
91 
FlushOutputData()92   bool FlushOutputData() override {
93     if (flushed_input_size_ == total_input_size_) {
94       return true;
95     }
96     flushed_input_size_ = total_input_size_;
97     ZSTD_inBuffer input = {nullptr, 0, 0};
98     size_t remaining = 0;
99     do {
100       out_buffer_.PrepareForInput();
101       ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
102       remaining = ZSTD_compressStream2(cctx_.get(), &output, &input, ZSTD_e_end);
103       if (ZSTD_isError(remaining)) {
104         LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(remaining);
105         return false;
106       }
107       out_buffer_.ProduceData(output.pos);
108       total_output_size_ += output.pos;
109     } while (remaining != 0);
110     return true;
111   }
112 
GetOutputData()113   std::string_view GetOutputData() override {
114     return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
115   }
116 
ConsumeOutputData(size_t size)117   void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
118 
119  private:
120   ZSTD_CCtx_pointer cctx_;
121   CompressionOutBuffer out_buffer_;
122   uint64_t flushed_input_size_ = 0;
123 };
124 
125 using ZSTD_DCtx_pointer = std::unique_ptr<ZSTD_DCtx, decltype(&ZSTD_freeDCtx)>;
126 
127 class ZstdDecompressor : public Decompressor {
128  public:
ZstdDecompressor(ZSTD_DCtx_pointer dctx)129   ZstdDecompressor(ZSTD_DCtx_pointer dctx)
130       : dctx_(std::move(dctx)), out_buffer_(ZSTD_DStreamOutSize()) {}
131 
AddInputData(const char * data,size_t size)132   bool AddInputData(const char* data, size_t size) override {
133     ZSTD_inBuffer input = {data, size, 0};
134     while (input.pos < input.size) {
135       out_buffer_.PrepareForInput();
136       ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
137       size_t remaining = ZSTD_decompressStream(dctx_.get(), &output, &input);
138       if (ZSTD_isError(remaining)) {
139         LOG(ERROR) << "ZSTD_decompressStream() failed: " << ZSTD_getErrorName(remaining);
140         return false;
141       }
142       out_buffer_.ProduceData(output.pos);
143     }
144     return true;
145   }
146 
GetOutputData()147   std::string_view GetOutputData() override {
148     return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
149   }
150 
ConsumeOutputData(size_t size)151   void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }
152 
153  private:
154   ZSTD_DCtx_pointer dctx_;
155   CompressionOutBuffer out_buffer_;
156 };
157 
158 }  // namespace
159 
~Compressor()160 Compressor::~Compressor() {}
161 
~Decompressor()162 Decompressor::~Decompressor() {}
163 
CreateZstdCompressor(size_t compression_level)164 std::unique_ptr<Compressor> CreateZstdCompressor(size_t compression_level) {
165   ZSTD_CCtx_pointer cctx(ZSTD_createCCtx(), ZSTD_freeCCtx);
166   if (!cctx) {
167     LOG(ERROR) << "ZSTD_createCCtx() failed";
168     return nullptr;
169   }
170   size_t err = ZSTD_CCtx_setParameter(cctx.get(), ZSTD_c_compressionLevel, compression_level);
171   if (ZSTD_isError(err)) {
172     LOG(ERROR) << "failed to set compression level: " << ZSTD_getErrorName(err);
173     return nullptr;
174   }
175   return std::unique_ptr<Compressor>(new ZstdCompressor(std::move(cctx)));
176 }
177 
CreateZstdDecompressor()178 std::unique_ptr<Decompressor> CreateZstdDecompressor() {
179   ZSTD_DCtx_pointer dctx(ZSTD_createDCtx(), ZSTD_freeDCtx);
180   if (!dctx) {
181     LOG(ERROR) << "ZSTD_createDCtx() failed";
182     return nullptr;
183   }
184   return std::unique_ptr<Decompressor>(new ZstdDecompressor(std::move(dctx)));
185 }
ZstdCompress(const char * input_data,size_t input_size,std::string & output_data)186 bool ZstdCompress(const char* input_data, size_t input_size, std::string& output_data) {
187   std::unique_ptr<Compressor> compressor = CreateZstdCompressor();
188   CHECK(compressor != nullptr);
189   if (!compressor->AddInputData(input_data, input_size)) {
190     return false;
191   }
192   if (!compressor->FlushOutputData()) {
193     return false;
194   }
195   std::string_view output = compressor->GetOutputData();
196   output_data.clear();
197   output_data.insert(0, output.data(), output.size());
198   return true;
199 }
200 
ZstdDecompress(const char * input_data,size_t input_size,std::string & output_data)201 bool ZstdDecompress(const char* input_data, size_t input_size, std::string& output_data) {
202   std::unique_ptr<Decompressor> decompressor = CreateZstdDecompressor();
203   CHECK(decompressor != nullptr);
204   if (!decompressor->AddInputData(input_data, input_size)) {
205     return false;
206   }
207   std::string_view output = decompressor->GetOutputData();
208   output_data.clear();
209   output_data.insert(0, output.data(), output.size());
210   return true;
211 }
212 
213 }  // namespace simpleperf
214