xref: /aosp_15_r20/external/zstd/contrib/pzstd/Pzstd.cpp (revision 01826a4963a0d8a59bc3812d29bdf0fb76416722)
1*01826a49SYabin Cui /*
2*01826a49SYabin Cui  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*01826a49SYabin Cui  * All rights reserved.
4*01826a49SYabin Cui  *
5*01826a49SYabin Cui  * This source code is licensed under both the BSD-style license (found in the
6*01826a49SYabin Cui  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7*01826a49SYabin Cui  * in the COPYING file in the root directory of this source tree).
8*01826a49SYabin Cui  */
9*01826a49SYabin Cui #include "platform.h"   /* Large Files support, SET_BINARY_MODE */
10*01826a49SYabin Cui #include "Pzstd.h"
11*01826a49SYabin Cui #include "SkippableFrame.h"
12*01826a49SYabin Cui #include "utils/FileSystem.h"
13*01826a49SYabin Cui #include "utils/Portability.h"
14*01826a49SYabin Cui #include "utils/Range.h"
15*01826a49SYabin Cui #include "utils/ScopeGuard.h"
16*01826a49SYabin Cui #include "utils/ThreadPool.h"
17*01826a49SYabin Cui #include "utils/WorkQueue.h"
18*01826a49SYabin Cui 
19*01826a49SYabin Cui #include <algorithm>
20*01826a49SYabin Cui #include <chrono>
21*01826a49SYabin Cui #include <cinttypes>
22*01826a49SYabin Cui #include <cstddef>
23*01826a49SYabin Cui #include <cstdio>
24*01826a49SYabin Cui #include <memory>
25*01826a49SYabin Cui #include <string>
26*01826a49SYabin Cui 
27*01826a49SYabin Cui 
28*01826a49SYabin Cui namespace pzstd {
29*01826a49SYabin Cui 
namespace {
// Path of the platform's null device. openOutputFile() compares the
// requested output path against it to skip the "already exists" check.
#if defined(_WIN32)
const std::string nullOutput{"nul"};
#else
const std::string nullOutput{"/dev/null"};
#endif
} // anonymous namespace
37*01826a49SYabin Cui 
38*01826a49SYabin Cui using std::size_t;
39*01826a49SYabin Cui 
fileSizeOrZero(const std::string & file)40*01826a49SYabin Cui static std::uintmax_t fileSizeOrZero(const std::string &file) {
41*01826a49SYabin Cui   if (file == "-") {
42*01826a49SYabin Cui     return 0;
43*01826a49SYabin Cui   }
44*01826a49SYabin Cui   std::error_code ec;
45*01826a49SYabin Cui   auto size = file_size(file, ec);
46*01826a49SYabin Cui   if (ec) {
47*01826a49SYabin Cui     size = 0;
48*01826a49SYabin Cui   }
49*01826a49SYabin Cui   return size;
50*01826a49SYabin Cui }
51*01826a49SYabin Cui 
handleOneInput(const Options & options,const std::string & inputFile,FILE * inputFd,const std::string & outputFile,FILE * outputFd,SharedState & state)52*01826a49SYabin Cui static std::uint64_t handleOneInput(const Options &options,
53*01826a49SYabin Cui                              const std::string &inputFile,
54*01826a49SYabin Cui                              FILE* inputFd,
55*01826a49SYabin Cui                              const std::string &outputFile,
56*01826a49SYabin Cui                              FILE* outputFd,
57*01826a49SYabin Cui                              SharedState& state) {
58*01826a49SYabin Cui   auto inputSize = fileSizeOrZero(inputFile);
59*01826a49SYabin Cui   // WorkQueue outlives ThreadPool so in the case of error we are certain
60*01826a49SYabin Cui   // we don't accidentally try to call push() on it after it is destroyed
61*01826a49SYabin Cui   WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1};
62*01826a49SYabin Cui   std::uint64_t bytesRead;
63*01826a49SYabin Cui   std::uint64_t bytesWritten;
64*01826a49SYabin Cui   {
65*01826a49SYabin Cui     // Initialize the (de)compression thread pool with numThreads
66*01826a49SYabin Cui     ThreadPool executor(options.numThreads);
67*01826a49SYabin Cui     // Run the reader thread on an extra thread
68*01826a49SYabin Cui     ThreadPool readExecutor(1);
69*01826a49SYabin Cui     if (!options.decompress) {
70*01826a49SYabin Cui       // Add a job that reads the input and starts all the compression jobs
71*01826a49SYabin Cui       readExecutor.add(
72*01826a49SYabin Cui           [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] {
73*01826a49SYabin Cui             bytesRead = asyncCompressChunks(
74*01826a49SYabin Cui                 state,
75*01826a49SYabin Cui                 outs,
76*01826a49SYabin Cui                 executor,
77*01826a49SYabin Cui                 inputFd,
78*01826a49SYabin Cui                 inputSize,
79*01826a49SYabin Cui                 options.numThreads,
80*01826a49SYabin Cui                 options.determineParameters());
81*01826a49SYabin Cui           });
82*01826a49SYabin Cui       // Start writing
83*01826a49SYabin Cui       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
84*01826a49SYabin Cui     } else {
85*01826a49SYabin Cui       // Add a job that reads the input and starts all the decompression jobs
86*01826a49SYabin Cui       readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] {
87*01826a49SYabin Cui         bytesRead = asyncDecompressFrames(state, outs, executor, inputFd);
88*01826a49SYabin Cui       });
89*01826a49SYabin Cui       // Start writing
90*01826a49SYabin Cui       bytesWritten = writeFile(state, outs, outputFd, options.decompress);
91*01826a49SYabin Cui     }
92*01826a49SYabin Cui   }
93*01826a49SYabin Cui   if (!state.errorHolder.hasError()) {
94*01826a49SYabin Cui     std::string inputFileName = inputFile == "-" ? "stdin" : inputFile;
95*01826a49SYabin Cui     std::string outputFileName = outputFile == "-" ? "stdout" : outputFile;
96*01826a49SYabin Cui     if (!options.decompress) {
97*01826a49SYabin Cui       double ratio = static_cast<double>(bytesWritten) /
98*01826a49SYabin Cui                      static_cast<double>(bytesRead + !bytesRead);
99*01826a49SYabin Cui       state.log(kLogInfo, "%-20s :%6.2f%%   (%6" PRIu64 " => %6" PRIu64
100*01826a49SYabin Cui                    " bytes, %s)\n",
101*01826a49SYabin Cui                    inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten,
102*01826a49SYabin Cui                    outputFileName.c_str());
103*01826a49SYabin Cui     } else {
104*01826a49SYabin Cui       state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n",
105*01826a49SYabin Cui                    inputFileName.c_str(),bytesWritten);
106*01826a49SYabin Cui     }
107*01826a49SYabin Cui   }
108*01826a49SYabin Cui   return bytesWritten;
109*01826a49SYabin Cui }
110*01826a49SYabin Cui 
openInputFile(const std::string & inputFile,ErrorHolder & errorHolder)111*01826a49SYabin Cui static FILE *openInputFile(const std::string &inputFile,
112*01826a49SYabin Cui                            ErrorHolder &errorHolder) {
113*01826a49SYabin Cui   if (inputFile == "-") {
114*01826a49SYabin Cui     SET_BINARY_MODE(stdin);
115*01826a49SYabin Cui     return stdin;
116*01826a49SYabin Cui   }
117*01826a49SYabin Cui   // Check if input file is a directory
118*01826a49SYabin Cui   {
119*01826a49SYabin Cui     std::error_code ec;
120*01826a49SYabin Cui     if (is_directory(inputFile, ec)) {
121*01826a49SYabin Cui       errorHolder.setError("Output file is a directory -- ignored");
122*01826a49SYabin Cui       return nullptr;
123*01826a49SYabin Cui     }
124*01826a49SYabin Cui   }
125*01826a49SYabin Cui   auto inputFd = std::fopen(inputFile.c_str(), "rb");
126*01826a49SYabin Cui   if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
127*01826a49SYabin Cui     return nullptr;
128*01826a49SYabin Cui   }
129*01826a49SYabin Cui   return inputFd;
130*01826a49SYabin Cui }
131*01826a49SYabin Cui 
openOutputFile(const Options & options,const std::string & outputFile,SharedState & state)132*01826a49SYabin Cui static FILE *openOutputFile(const Options &options,
133*01826a49SYabin Cui                             const std::string &outputFile,
134*01826a49SYabin Cui                             SharedState& state) {
135*01826a49SYabin Cui   if (outputFile == "-") {
136*01826a49SYabin Cui     SET_BINARY_MODE(stdout);
137*01826a49SYabin Cui     return stdout;
138*01826a49SYabin Cui   }
139*01826a49SYabin Cui   // Check if the output file exists and then open it
140*01826a49SYabin Cui   if (!options.overwrite && outputFile != nullOutput) {
141*01826a49SYabin Cui     auto outputFd = std::fopen(outputFile.c_str(), "rb");
142*01826a49SYabin Cui     if (outputFd != nullptr) {
143*01826a49SYabin Cui       std::fclose(outputFd);
144*01826a49SYabin Cui       if (!state.log.logsAt(kLogInfo)) {
145*01826a49SYabin Cui         state.errorHolder.setError("Output file exists");
146*01826a49SYabin Cui         return nullptr;
147*01826a49SYabin Cui       }
148*01826a49SYabin Cui       state.log(
149*01826a49SYabin Cui           kLogInfo,
150*01826a49SYabin Cui           "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
151*01826a49SYabin Cui           outputFile.c_str());
152*01826a49SYabin Cui       int c = getchar();
153*01826a49SYabin Cui       if (c != 'y' && c != 'Y') {
154*01826a49SYabin Cui         state.errorHolder.setError("Not overwritten");
155*01826a49SYabin Cui         return nullptr;
156*01826a49SYabin Cui       }
157*01826a49SYabin Cui     }
158*01826a49SYabin Cui   }
159*01826a49SYabin Cui   auto outputFd = std::fopen(outputFile.c_str(), "wb");
160*01826a49SYabin Cui   if (!state.errorHolder.check(
161*01826a49SYabin Cui           outputFd != nullptr, "Failed to open output file")) {
162*01826a49SYabin Cui     return nullptr;
163*01826a49SYabin Cui   }
164*01826a49SYabin Cui   return outputFd;
165*01826a49SYabin Cui }
166*01826a49SYabin Cui 
pzstdMain(const Options & options)167*01826a49SYabin Cui int pzstdMain(const Options &options) {
168*01826a49SYabin Cui   int returnCode = 0;
169*01826a49SYabin Cui   SharedState state(options);
170*01826a49SYabin Cui   for (const auto& input : options.inputFiles) {
171*01826a49SYabin Cui     // Setup the shared state
172*01826a49SYabin Cui     auto printErrorGuard = makeScopeGuard([&] {
173*01826a49SYabin Cui       if (state.errorHolder.hasError()) {
174*01826a49SYabin Cui         returnCode = 1;
175*01826a49SYabin Cui         state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(),
176*01826a49SYabin Cui                   state.errorHolder.getError().c_str());
177*01826a49SYabin Cui       }
178*01826a49SYabin Cui     });
179*01826a49SYabin Cui     // Open the input file
180*01826a49SYabin Cui     auto inputFd = openInputFile(input, state.errorHolder);
181*01826a49SYabin Cui     if (inputFd == nullptr) {
182*01826a49SYabin Cui       continue;
183*01826a49SYabin Cui     }
184*01826a49SYabin Cui     auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
185*01826a49SYabin Cui     // Open the output file
186*01826a49SYabin Cui     auto outputFile = options.getOutputFile(input);
187*01826a49SYabin Cui     if (!state.errorHolder.check(outputFile != "",
188*01826a49SYabin Cui                            "Input file does not have extension .zst")) {
189*01826a49SYabin Cui       continue;
190*01826a49SYabin Cui     }
191*01826a49SYabin Cui     auto outputFd = openOutputFile(options, outputFile, state);
192*01826a49SYabin Cui     if (outputFd == nullptr) {
193*01826a49SYabin Cui       continue;
194*01826a49SYabin Cui     }
195*01826a49SYabin Cui     auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
196*01826a49SYabin Cui     // (de)compress the file
197*01826a49SYabin Cui     handleOneInput(options, input, inputFd, outputFile, outputFd, state);
198*01826a49SYabin Cui     if (state.errorHolder.hasError()) {
199*01826a49SYabin Cui       continue;
200*01826a49SYabin Cui     }
201*01826a49SYabin Cui     // Delete the input file if necessary
202*01826a49SYabin Cui     if (!options.keepSource) {
203*01826a49SYabin Cui       // Be sure that we are done and have written everything before we delete
204*01826a49SYabin Cui       if (!state.errorHolder.check(std::fclose(inputFd) == 0,
205*01826a49SYabin Cui                              "Failed to close input file")) {
206*01826a49SYabin Cui         continue;
207*01826a49SYabin Cui       }
208*01826a49SYabin Cui       closeInputGuard.dismiss();
209*01826a49SYabin Cui       if (!state.errorHolder.check(std::fclose(outputFd) == 0,
210*01826a49SYabin Cui                              "Failed to close output file")) {
211*01826a49SYabin Cui         continue;
212*01826a49SYabin Cui       }
213*01826a49SYabin Cui       closeOutputGuard.dismiss();
214*01826a49SYabin Cui       if (std::remove(input.c_str()) != 0) {
215*01826a49SYabin Cui         state.errorHolder.setError("Failed to remove input file");
216*01826a49SYabin Cui         continue;
217*01826a49SYabin Cui       }
218*01826a49SYabin Cui     }
219*01826a49SYabin Cui   }
220*01826a49SYabin Cui   // Returns 1 if any of the files failed to (de)compress.
221*01826a49SYabin Cui   return returnCode;
222*01826a49SYabin Cui }
223*01826a49SYabin Cui 
224*01826a49SYabin Cui /// Construct a `ZSTD_inBuffer` that points to the data in `buffer`.
makeZstdInBuffer(const Buffer & buffer)225*01826a49SYabin Cui static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
226*01826a49SYabin Cui   return ZSTD_inBuffer{buffer.data(), buffer.size(), 0};
227*01826a49SYabin Cui }
228*01826a49SYabin Cui 
229*01826a49SYabin Cui /**
230*01826a49SYabin Cui  * Advance `buffer` and `inBuffer` by the amount of data read, as indicated by
231*01826a49SYabin Cui  * `inBuffer.pos`.
232*01826a49SYabin Cui  */
advance(Buffer & buffer,ZSTD_inBuffer & inBuffer)233*01826a49SYabin Cui void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
234*01826a49SYabin Cui   auto pos = inBuffer.pos;
235*01826a49SYabin Cui   inBuffer.src = static_cast<const unsigned char*>(inBuffer.src) + pos;
236*01826a49SYabin Cui   inBuffer.size -= pos;
237*01826a49SYabin Cui   inBuffer.pos = 0;
238*01826a49SYabin Cui   return buffer.advance(pos);
239*01826a49SYabin Cui }
240*01826a49SYabin Cui 
241*01826a49SYabin Cui /// Construct a `ZSTD_outBuffer` that points to the data in `buffer`.
makeZstdOutBuffer(Buffer & buffer)242*01826a49SYabin Cui static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
243*01826a49SYabin Cui   return ZSTD_outBuffer{buffer.data(), buffer.size(), 0};
244*01826a49SYabin Cui }
245*01826a49SYabin Cui 
246*01826a49SYabin Cui /**
247*01826a49SYabin Cui  * Split `buffer` and advance `outBuffer` by the amount of data written, as
248*01826a49SYabin Cui  * indicated by `outBuffer.pos`.
249*01826a49SYabin Cui  */
split(Buffer & buffer,ZSTD_outBuffer & outBuffer)250*01826a49SYabin Cui Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
251*01826a49SYabin Cui   auto pos = outBuffer.pos;
252*01826a49SYabin Cui   outBuffer.dst = static_cast<unsigned char*>(outBuffer.dst) + pos;
253*01826a49SYabin Cui   outBuffer.size -= pos;
254*01826a49SYabin Cui   outBuffer.pos = 0;
255*01826a49SYabin Cui   return buffer.splitAt(pos);
256*01826a49SYabin Cui }
257*01826a49SYabin Cui 
258*01826a49SYabin Cui /**
259*01826a49SYabin Cui  * Stream chunks of input from `in`, compress it, and stream it out to `out`.
260*01826a49SYabin Cui  *
261*01826a49SYabin Cui  * @param state        The shared state
262*01826a49SYabin Cui  * @param in           Queue that we `pop()` input buffers from
263*01826a49SYabin Cui  * @param out          Queue that we `push()` compressed output buffers to
264*01826a49SYabin Cui  * @param maxInputSize An upper bound on the size of the input
265*01826a49SYabin Cui  */
compress(SharedState & state,std::shared_ptr<BufferWorkQueue> in,std::shared_ptr<BufferWorkQueue> out,size_t maxInputSize)266*01826a49SYabin Cui static void compress(
267*01826a49SYabin Cui     SharedState& state,
268*01826a49SYabin Cui     std::shared_ptr<BufferWorkQueue> in,
269*01826a49SYabin Cui     std::shared_ptr<BufferWorkQueue> out,
270*01826a49SYabin Cui     size_t maxInputSize) {
271*01826a49SYabin Cui   auto& errorHolder = state.errorHolder;
272*01826a49SYabin Cui   auto guard = makeScopeGuard([&] { out->finish(); });
273*01826a49SYabin Cui   // Initialize the CCtx
274*01826a49SYabin Cui   auto ctx = state.cStreamPool->get();
275*01826a49SYabin Cui   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
276*01826a49SYabin Cui     return;
277*01826a49SYabin Cui   }
278*01826a49SYabin Cui   {
279*01826a49SYabin Cui     auto err = ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only);
280*01826a49SYabin Cui     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
281*01826a49SYabin Cui       return;
282*01826a49SYabin Cui     }
283*01826a49SYabin Cui   }
284*01826a49SYabin Cui 
285*01826a49SYabin Cui   // Allocate space for the result
286*01826a49SYabin Cui   auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
287*01826a49SYabin Cui   auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
288*01826a49SYabin Cui   {
289*01826a49SYabin Cui     Buffer inBuffer;
290*01826a49SYabin Cui     // Read a buffer in from the input queue
291*01826a49SYabin Cui     while (in->pop(inBuffer) && !errorHolder.hasError()) {
292*01826a49SYabin Cui       auto zstdInBuffer = makeZstdInBuffer(inBuffer);
293*01826a49SYabin Cui       // Compress the whole buffer and send it to the output queue
294*01826a49SYabin Cui       while (!inBuffer.empty() && !errorHolder.hasError()) {
295*01826a49SYabin Cui         if (!errorHolder.check(
296*01826a49SYabin Cui                 !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
297*01826a49SYabin Cui           return;
298*01826a49SYabin Cui         }
299*01826a49SYabin Cui         // Compress
300*01826a49SYabin Cui         auto err =
301*01826a49SYabin Cui             ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
302*01826a49SYabin Cui         if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
303*01826a49SYabin Cui           return;
304*01826a49SYabin Cui         }
305*01826a49SYabin Cui         // Split the compressed data off outBuffer and pass to the output queue
306*01826a49SYabin Cui         out->push(split(outBuffer, zstdOutBuffer));
307*01826a49SYabin Cui         // Forget about the data we already compressed
308*01826a49SYabin Cui         advance(inBuffer, zstdInBuffer);
309*01826a49SYabin Cui       }
310*01826a49SYabin Cui     }
311*01826a49SYabin Cui   }
312*01826a49SYabin Cui   // Write the epilog
313*01826a49SYabin Cui   size_t bytesLeft;
314*01826a49SYabin Cui   do {
315*01826a49SYabin Cui     if (!errorHolder.check(
316*01826a49SYabin Cui             !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
317*01826a49SYabin Cui       return;
318*01826a49SYabin Cui     }
319*01826a49SYabin Cui     bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
320*01826a49SYabin Cui     if (!errorHolder.check(
321*01826a49SYabin Cui             !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
322*01826a49SYabin Cui       return;
323*01826a49SYabin Cui     }
324*01826a49SYabin Cui     out->push(split(outBuffer, zstdOutBuffer));
325*01826a49SYabin Cui   } while (bytesLeft != 0 && !errorHolder.hasError());
326*01826a49SYabin Cui }
327*01826a49SYabin Cui 
328*01826a49SYabin Cui /**
329*01826a49SYabin Cui  * Calculates how large each independently compressed frame should be.
330*01826a49SYabin Cui  *
331*01826a49SYabin Cui  * @param size       The size of the source if known, 0 otherwise
332*01826a49SYabin Cui  * @param numThreads The number of threads available to run compression jobs on
333*01826a49SYabin Cui  * @param params     The zstd parameters to be used for compression
334*01826a49SYabin Cui  */
calculateStep(std::uintmax_t size,size_t numThreads,const ZSTD_parameters & params)335*01826a49SYabin Cui static size_t calculateStep(
336*01826a49SYabin Cui     std::uintmax_t size,
337*01826a49SYabin Cui     size_t numThreads,
338*01826a49SYabin Cui     const ZSTD_parameters &params) {
339*01826a49SYabin Cui   (void)size;
340*01826a49SYabin Cui   (void)numThreads;
341*01826a49SYabin Cui   // Not validated to work correctly for window logs > 23.
342*01826a49SYabin Cui   // It will definitely fail if windowLog + 2 is >= 4GB because
343*01826a49SYabin Cui   // the skippable frame can only store sizes up to 4GB.
344*01826a49SYabin Cui   assert(params.cParams.windowLog <= 23);
345*01826a49SYabin Cui   return size_t{1} << (params.cParams.windowLog + 2);
346*01826a49SYabin Cui }
347*01826a49SYabin Cui 
namespace {
/// Outcome of inspecting a stream: more data may follow, end-of-file was
/// reached, or a stream error is pending.
enum class FileStatus { Continue, Done, Error };

/// Classifies the stream `fd` from its end-of-file and error indicators.
/// End-of-file is checked first, so it wins if both indicators are set.
FileStatus fileStatus(FILE* fd) {
  if (std::feof(fd) != 0) {
    return FileStatus::Done;
  }
  return std::ferror(fd) != 0 ? FileStatus::Error : FileStatus::Continue;
}
} // anonymous namespace
360*01826a49SYabin Cui 
361*01826a49SYabin Cui /**
362*01826a49SYabin Cui  * Reads `size` data in chunks of `chunkSize` and puts it into `queue`.
363*01826a49SYabin Cui  * Will read less if an error or EOF occurs.
364*01826a49SYabin Cui  * Returns the status of the file after all of the reads have occurred.
365*01826a49SYabin Cui  */
366*01826a49SYabin Cui static FileStatus
readData(BufferWorkQueue & queue,size_t chunkSize,size_t size,FILE * fd,std::uint64_t * totalBytesRead)367*01826a49SYabin Cui readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd,
368*01826a49SYabin Cui          std::uint64_t *totalBytesRead) {
369*01826a49SYabin Cui   Buffer buffer(size);
370*01826a49SYabin Cui   while (!buffer.empty()) {
371*01826a49SYabin Cui     auto bytesRead =
372*01826a49SYabin Cui         std::fread(buffer.data(), 1, std::min(chunkSize, buffer.size()), fd);
373*01826a49SYabin Cui     *totalBytesRead += bytesRead;
374*01826a49SYabin Cui     queue.push(buffer.splitAt(bytesRead));
375*01826a49SYabin Cui     auto status = fileStatus(fd);
376*01826a49SYabin Cui     if (status != FileStatus::Continue) {
377*01826a49SYabin Cui       return status;
378*01826a49SYabin Cui     }
379*01826a49SYabin Cui   }
380*01826a49SYabin Cui   return FileStatus::Continue;
381*01826a49SYabin Cui }
382*01826a49SYabin Cui 
asyncCompressChunks(SharedState & state,WorkQueue<std::shared_ptr<BufferWorkQueue>> & chunks,ThreadPool & executor,FILE * fd,std::uintmax_t size,size_t numThreads,ZSTD_parameters params)383*01826a49SYabin Cui std::uint64_t asyncCompressChunks(
384*01826a49SYabin Cui     SharedState& state,
385*01826a49SYabin Cui     WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
386*01826a49SYabin Cui     ThreadPool& executor,
387*01826a49SYabin Cui     FILE* fd,
388*01826a49SYabin Cui     std::uintmax_t size,
389*01826a49SYabin Cui     size_t numThreads,
390*01826a49SYabin Cui     ZSTD_parameters params) {
391*01826a49SYabin Cui   auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
392*01826a49SYabin Cui   std::uint64_t bytesRead = 0;
393*01826a49SYabin Cui 
394*01826a49SYabin Cui   // Break the input up into chunks of size `step` and compress each chunk
395*01826a49SYabin Cui   // independently.
396*01826a49SYabin Cui   size_t step = calculateStep(size, numThreads, params);
397*01826a49SYabin Cui   state.log(kLogDebug, "Chosen frame size: %zu\n", step);
398*01826a49SYabin Cui   auto status = FileStatus::Continue;
399*01826a49SYabin Cui   while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
400*01826a49SYabin Cui     // Make a new input queue that we will put the chunk's input data into.
401*01826a49SYabin Cui     auto in = std::make_shared<BufferWorkQueue>();
402*01826a49SYabin Cui     auto inGuard = makeScopeGuard([&] { in->finish(); });
403*01826a49SYabin Cui     // Make a new output queue that compress will put the compressed data into.
404*01826a49SYabin Cui     auto out = std::make_shared<BufferWorkQueue>();
405*01826a49SYabin Cui     // Start compression in the thread pool
406*01826a49SYabin Cui     executor.add([&state, in, out, step] {
407*01826a49SYabin Cui       return compress(
408*01826a49SYabin Cui           state, std::move(in), std::move(out), step);
409*01826a49SYabin Cui     });
410*01826a49SYabin Cui     // Pass the output queue to the writer thread.
411*01826a49SYabin Cui     chunks.push(std::move(out));
412*01826a49SYabin Cui     state.log(kLogVerbose, "%s\n", "Starting a new frame");
413*01826a49SYabin Cui     // Fill the input queue for the compression job we just started
414*01826a49SYabin Cui     status = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead);
415*01826a49SYabin Cui   }
416*01826a49SYabin Cui   state.errorHolder.check(status != FileStatus::Error, "Error reading input");
417*01826a49SYabin Cui   return bytesRead;
418*01826a49SYabin Cui }
419*01826a49SYabin Cui 
420*01826a49SYabin Cui /**
421*01826a49SYabin Cui  * Decompress a frame, whose data is streamed into `in`, and stream the output
422*01826a49SYabin Cui  * to `out`.
423*01826a49SYabin Cui  *
424*01826a49SYabin Cui  * @param state        The shared state
425*01826a49SYabin Cui  * @param in           Queue that we `pop()` input buffers from. It contains
426*01826a49SYabin Cui  *                      exactly one compressed frame.
427*01826a49SYabin Cui  * @param out          Queue that we `push()` decompressed output buffers to
428*01826a49SYabin Cui  */
decompress(SharedState & state,std::shared_ptr<BufferWorkQueue> in,std::shared_ptr<BufferWorkQueue> out)429*01826a49SYabin Cui static void decompress(
430*01826a49SYabin Cui     SharedState& state,
431*01826a49SYabin Cui     std::shared_ptr<BufferWorkQueue> in,
432*01826a49SYabin Cui     std::shared_ptr<BufferWorkQueue> out) {
433*01826a49SYabin Cui   auto& errorHolder = state.errorHolder;
434*01826a49SYabin Cui   auto guard = makeScopeGuard([&] { out->finish(); });
435*01826a49SYabin Cui   // Initialize the DCtx
436*01826a49SYabin Cui   auto ctx = state.dStreamPool->get();
437*01826a49SYabin Cui   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
438*01826a49SYabin Cui     return;
439*01826a49SYabin Cui   }
440*01826a49SYabin Cui   {
441*01826a49SYabin Cui     auto err = ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only);
442*01826a49SYabin Cui     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
443*01826a49SYabin Cui       return;
444*01826a49SYabin Cui     }
445*01826a49SYabin Cui   }
446*01826a49SYabin Cui 
447*01826a49SYabin Cui   const size_t outSize = ZSTD_DStreamOutSize();
448*01826a49SYabin Cui   Buffer inBuffer;
449*01826a49SYabin Cui   size_t returnCode = 0;
450*01826a49SYabin Cui   // Read a buffer in from the input queue
451*01826a49SYabin Cui   while (in->pop(inBuffer) && !errorHolder.hasError()) {
452*01826a49SYabin Cui     auto zstdInBuffer = makeZstdInBuffer(inBuffer);
453*01826a49SYabin Cui     // Decompress the whole buffer and send it to the output queue
454*01826a49SYabin Cui     while (!inBuffer.empty() && !errorHolder.hasError()) {
455*01826a49SYabin Cui       // Allocate a buffer with at least outSize bytes.
456*01826a49SYabin Cui       Buffer outBuffer(outSize);
457*01826a49SYabin Cui       auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
458*01826a49SYabin Cui       // Decompress
459*01826a49SYabin Cui       returnCode =
460*01826a49SYabin Cui           ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
461*01826a49SYabin Cui       if (!errorHolder.check(
462*01826a49SYabin Cui               !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
463*01826a49SYabin Cui         return;
464*01826a49SYabin Cui       }
465*01826a49SYabin Cui       // Pass the buffer with the decompressed data to the output queue
466*01826a49SYabin Cui       out->push(split(outBuffer, zstdOutBuffer));
467*01826a49SYabin Cui       // Advance past the input we already read
468*01826a49SYabin Cui       advance(inBuffer, zstdInBuffer);
469*01826a49SYabin Cui       if (returnCode == 0) {
470*01826a49SYabin Cui         // The frame is over, prepare to (maybe) start a new frame
471*01826a49SYabin Cui         ZSTD_initDStream(ctx.get());
472*01826a49SYabin Cui       }
473*01826a49SYabin Cui     }
474*01826a49SYabin Cui   }
475*01826a49SYabin Cui   if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
476*01826a49SYabin Cui     return;
477*01826a49SYabin Cui   }
478*01826a49SYabin Cui   // We've given ZSTD_decompressStream all of our data, but there may still
479*01826a49SYabin Cui   // be data to read.
480*01826a49SYabin Cui   while (returnCode == 1) {
481*01826a49SYabin Cui     // Allocate a buffer with at least outSize bytes.
482*01826a49SYabin Cui     Buffer outBuffer(outSize);
483*01826a49SYabin Cui     auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
484*01826a49SYabin Cui     // Pass in no input.
485*01826a49SYabin Cui     ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
486*01826a49SYabin Cui     // Decompress
487*01826a49SYabin Cui     returnCode =
488*01826a49SYabin Cui         ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
489*01826a49SYabin Cui     if (!errorHolder.check(
490*01826a49SYabin Cui             !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
491*01826a49SYabin Cui       return;
492*01826a49SYabin Cui     }
493*01826a49SYabin Cui     // Pass the buffer with the decompressed data to the output queue
494*01826a49SYabin Cui     out->push(split(outBuffer, zstdOutBuffer));
495*01826a49SYabin Cui   }
496*01826a49SYabin Cui }
497*01826a49SYabin Cui 
asyncDecompressFrames(SharedState & state,WorkQueue<std::shared_ptr<BufferWorkQueue>> & frames,ThreadPool & executor,FILE * fd)498*01826a49SYabin Cui std::uint64_t asyncDecompressFrames(
499*01826a49SYabin Cui     SharedState& state,
500*01826a49SYabin Cui     WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
501*01826a49SYabin Cui     ThreadPool& executor,
502*01826a49SYabin Cui     FILE* fd) {
503*01826a49SYabin Cui   auto framesGuard = makeScopeGuard([&] { frames.finish(); });
504*01826a49SYabin Cui   std::uint64_t totalBytesRead = 0;
505*01826a49SYabin Cui 
506*01826a49SYabin Cui   // Split the source up into its component frames.
507*01826a49SYabin Cui   // If we find our recognized skippable frame we know the next frames size
508*01826a49SYabin Cui   // which means that we can decompress each standard frame in independently.
509*01826a49SYabin Cui   // Otherwise, we will decompress using only one decompression task.
510*01826a49SYabin Cui   const size_t chunkSize = ZSTD_DStreamInSize();
511*01826a49SYabin Cui   auto status = FileStatus::Continue;
512*01826a49SYabin Cui   while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
513*01826a49SYabin Cui     // Make a new input queue that we will put the frames's bytes into.
514*01826a49SYabin Cui     auto in = std::make_shared<BufferWorkQueue>();
515*01826a49SYabin Cui     auto inGuard = makeScopeGuard([&] { in->finish(); });
516*01826a49SYabin Cui     // Make a output queue that decompress will put the decompressed data into
517*01826a49SYabin Cui     auto out = std::make_shared<BufferWorkQueue>();
518*01826a49SYabin Cui 
519*01826a49SYabin Cui     size_t frameSize;
520*01826a49SYabin Cui     {
521*01826a49SYabin Cui       // Calculate the size of the next frame.
522*01826a49SYabin Cui       // frameSize is 0 if the frame info can't be decoded.
523*01826a49SYabin Cui       Buffer buffer(SkippableFrame::kSize);
524*01826a49SYabin Cui       auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
525*01826a49SYabin Cui       totalBytesRead += bytesRead;
526*01826a49SYabin Cui       status = fileStatus(fd);
527*01826a49SYabin Cui       if (bytesRead == 0 && status != FileStatus::Continue) {
528*01826a49SYabin Cui         break;
529*01826a49SYabin Cui       }
530*01826a49SYabin Cui       buffer.subtract(buffer.size() - bytesRead);
531*01826a49SYabin Cui       frameSize = SkippableFrame::tryRead(buffer.range());
532*01826a49SYabin Cui       in->push(std::move(buffer));
533*01826a49SYabin Cui     }
534*01826a49SYabin Cui     if (frameSize == 0) {
535*01826a49SYabin Cui       // We hit a non SkippableFrame, so this will be the last job.
536*01826a49SYabin Cui       // Make sure that we don't use too much memory
537*01826a49SYabin Cui       in->setMaxSize(64);
538*01826a49SYabin Cui       out->setMaxSize(64);
539*01826a49SYabin Cui     }
540*01826a49SYabin Cui     // Start decompression in the thread pool
541*01826a49SYabin Cui     executor.add([&state, in, out] {
542*01826a49SYabin Cui       return decompress(state, std::move(in), std::move(out));
543*01826a49SYabin Cui     });
544*01826a49SYabin Cui     // Pass the output queue to the writer thread
545*01826a49SYabin Cui     frames.push(std::move(out));
546*01826a49SYabin Cui     if (frameSize == 0) {
547*01826a49SYabin Cui       // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
548*01826a49SYabin Cui       // Pass the rest of the source to this decompression task
549*01826a49SYabin Cui       state.log(kLogVerbose, "%s\n",
550*01826a49SYabin Cui           "Input not in pzstd format, falling back to serial decompression");
551*01826a49SYabin Cui       while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
552*01826a49SYabin Cui         status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
553*01826a49SYabin Cui       }
554*01826a49SYabin Cui       break;
555*01826a49SYabin Cui     }
556*01826a49SYabin Cui     state.log(kLogVerbose, "Decompressing a frame of size %zu", frameSize);
557*01826a49SYabin Cui     // Fill the input queue for the decompression job we just started
558*01826a49SYabin Cui     status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
559*01826a49SYabin Cui   }
560*01826a49SYabin Cui   state.errorHolder.check(status != FileStatus::Error, "Error reading input");
561*01826a49SYabin Cui   return totalBytesRead;
562*01826a49SYabin Cui }
563*01826a49SYabin Cui 
564*01826a49SYabin Cui /// Write `data` to `fd`, returns true iff success.
writeData(ByteRange data,FILE * fd)565*01826a49SYabin Cui static bool writeData(ByteRange data, FILE* fd) {
566*01826a49SYabin Cui   while (!data.empty()) {
567*01826a49SYabin Cui     data.advance(std::fwrite(data.begin(), 1, data.size(), fd));
568*01826a49SYabin Cui     if (std::ferror(fd)) {
569*01826a49SYabin Cui       return false;
570*01826a49SYabin Cui     }
571*01826a49SYabin Cui   }
572*01826a49SYabin Cui   return true;
573*01826a49SYabin Cui }
574*01826a49SYabin Cui 
writeFile(SharedState & state,WorkQueue<std::shared_ptr<BufferWorkQueue>> & outs,FILE * outputFd,bool decompress)575*01826a49SYabin Cui std::uint64_t writeFile(
576*01826a49SYabin Cui     SharedState& state,
577*01826a49SYabin Cui     WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
578*01826a49SYabin Cui     FILE* outputFd,
579*01826a49SYabin Cui     bool decompress) {
580*01826a49SYabin Cui   auto& errorHolder = state.errorHolder;
581*01826a49SYabin Cui   auto lineClearGuard = makeScopeGuard([&state] {
582*01826a49SYabin Cui     state.log.clear(kLogInfo);
583*01826a49SYabin Cui   });
584*01826a49SYabin Cui   std::uint64_t bytesWritten = 0;
585*01826a49SYabin Cui   std::shared_ptr<BufferWorkQueue> out;
586*01826a49SYabin Cui   // Grab the output queue for each decompression job (in order).
587*01826a49SYabin Cui   while (outs.pop(out)) {
588*01826a49SYabin Cui     if (errorHolder.hasError()) {
589*01826a49SYabin Cui       continue;
590*01826a49SYabin Cui     }
591*01826a49SYabin Cui     if (!decompress) {
592*01826a49SYabin Cui       // If we are compressing and want to write skippable frames we can't
593*01826a49SYabin Cui       // start writing before compression is done because we need to know the
594*01826a49SYabin Cui       // compressed size.
595*01826a49SYabin Cui       // Wait for the compressed size to be available and write skippable frame
596*01826a49SYabin Cui       assert(uint64_t(out->size()) < uint64_t(1) << 32);
597*01826a49SYabin Cui       SkippableFrame frame(uint32_t(out->size()));
598*01826a49SYabin Cui       if (!writeData(frame.data(), outputFd)) {
599*01826a49SYabin Cui         errorHolder.setError("Failed to write output");
600*01826a49SYabin Cui         return bytesWritten;
601*01826a49SYabin Cui       }
602*01826a49SYabin Cui       bytesWritten += frame.kSize;
603*01826a49SYabin Cui     }
604*01826a49SYabin Cui     // For each chunk of the frame: Pop it from the queue and write it
605*01826a49SYabin Cui     Buffer buffer;
606*01826a49SYabin Cui     while (out->pop(buffer) && !errorHolder.hasError()) {
607*01826a49SYabin Cui       if (!writeData(buffer.range(), outputFd)) {
608*01826a49SYabin Cui         errorHolder.setError("Failed to write output");
609*01826a49SYabin Cui         return bytesWritten;
610*01826a49SYabin Cui       }
611*01826a49SYabin Cui       bytesWritten += buffer.size();
612*01826a49SYabin Cui       state.log.update(kLogInfo, "Written: %u MB   ",
613*01826a49SYabin Cui                 static_cast<std::uint32_t>(bytesWritten >> 20));
614*01826a49SYabin Cui     }
615*01826a49SYabin Cui   }
616*01826a49SYabin Cui   return bytesWritten;
617*01826a49SYabin Cui }
618*01826a49SYabin Cui }
619