1*ec63e07aSXin Li // Copyright 2022 Google LLC
2*ec63e07aSXin Li //
3*ec63e07aSXin Li // Licensed under the Apache License, Version 2.0 (the "License");
4*ec63e07aSXin Li // you may not use this file except in compliance with the License.
5*ec63e07aSXin Li // You may obtain a copy of the License at
6*ec63e07aSXin Li //
7*ec63e07aSXin Li // https://www.apache.org/licenses/LICENSE-2.0
8*ec63e07aSXin Li //
9*ec63e07aSXin Li // Unless required by applicable law or agreed to in writing, software
10*ec63e07aSXin Li // distributed under the License is distributed on an "AS IS" BASIS,
11*ec63e07aSXin Li // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*ec63e07aSXin Li // See the License for the specific language governing permissions and
13*ec63e07aSXin Li // limitations under the License.
14*ec63e07aSXin Li
15*ec63e07aSXin Li #include <fcntl.h>
16*ec63e07aSXin Li #include <unistd.h>
17*ec63e07aSXin Li
18*ec63e07aSXin Li #include <cstdlib>
19*ec63e07aSXin Li #include <fstream>
20*ec63e07aSXin Li #include <iostream>
21*ec63e07aSXin Li #include <string>
22*ec63e07aSXin Li #include <vector>
23*ec63e07aSXin Li
24*ec63e07aSXin Li #include "absl/flags/flag.h"
25*ec63e07aSXin Li #include "absl/flags/parse.h"
26*ec63e07aSXin Li #include "absl/log/globals.h"
27*ec63e07aSXin Li #include "absl/log/initialize.h"
28*ec63e07aSXin Li #include "contrib/zstd/sandboxed.h"
29*ec63e07aSXin Li #include "contrib/zstd/utils/utils_zstd.h"
30*ec63e07aSXin Li
31*ec63e07aSXin Li ABSL_FLAG(bool, stream, false, "stream data to sandbox");
32*ec63e07aSXin Li ABSL_FLAG(bool, decompress, false, "decompress");
33*ec63e07aSXin Li ABSL_FLAG(bool, memory_mode, false, "in memory operations");
34*ec63e07aSXin Li ABSL_FLAG(uint32_t, level, 0, "compression level");
35*ec63e07aSXin Li
Stream(ZstdApi & api,std::string infile_s,std::string outfile_s)36*ec63e07aSXin Li absl::Status Stream(ZstdApi& api, std::string infile_s, std::string outfile_s) {
37*ec63e07aSXin Li std::ifstream infile(infile_s, std::ios::binary);
38*ec63e07aSXin Li if (!infile.is_open()) {
39*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat("Unable to open ", infile_s));
40*ec63e07aSXin Li }
41*ec63e07aSXin Li std::ofstream outfile(outfile_s, std::ios::binary);
42*ec63e07aSXin Li if (!outfile.is_open()) {
43*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat("Unable to open ", outfile_s));
44*ec63e07aSXin Li }
45*ec63e07aSXin Li
46*ec63e07aSXin Li if (absl::GetFlag(FLAGS_memory_mode)) {
47*ec63e07aSXin Li if (absl::GetFlag(FLAGS_decompress)) {
48*ec63e07aSXin Li return DecompressInMemory(api, infile, outfile);
49*ec63e07aSXin Li }
50*ec63e07aSXin Li return CompressInMemory(api, infile, outfile, absl::GetFlag(FLAGS_level));
51*ec63e07aSXin Li }
52*ec63e07aSXin Li if (absl::GetFlag(FLAGS_decompress)) {
53*ec63e07aSXin Li return DecompressStream(api, infile, outfile);
54*ec63e07aSXin Li }
55*ec63e07aSXin Li return CompressStream(api, infile, outfile, absl::GetFlag(FLAGS_level));
56*ec63e07aSXin Li }
57*ec63e07aSXin Li
FileDescriptor(ZstdApi & api,std::string infile_s,std::string outfile_s)58*ec63e07aSXin Li absl::Status FileDescriptor(ZstdApi& api, std::string infile_s,
59*ec63e07aSXin Li std::string outfile_s) {
60*ec63e07aSXin Li sapi::v::Fd infd(open(infile_s.c_str(), O_RDONLY));
61*ec63e07aSXin Li if (infd.GetValue() < 0) {
62*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat("Unable to open ", infile_s));
63*ec63e07aSXin Li }
64*ec63e07aSXin Li sapi::v::Fd outfd(open(outfile_s.c_str(), O_WRONLY | O_CREAT));
65*ec63e07aSXin Li if (outfd.GetValue() < 0) {
66*ec63e07aSXin Li return absl::UnavailableError(absl::StrCat("Unable to open ", outfile_s));
67*ec63e07aSXin Li }
68*ec63e07aSXin Li
69*ec63e07aSXin Li if (absl::GetFlag(FLAGS_memory_mode)) {
70*ec63e07aSXin Li if (absl::GetFlag(FLAGS_decompress)) {
71*ec63e07aSXin Li return DecompressInMemoryFD(api, infd, outfd);
72*ec63e07aSXin Li }
73*ec63e07aSXin Li return CompressInMemoryFD(api, infd, outfd, absl::GetFlag(FLAGS_level));
74*ec63e07aSXin Li }
75*ec63e07aSXin Li if (absl::GetFlag(FLAGS_decompress)) {
76*ec63e07aSXin Li return DecompressStreamFD(api, infd, outfd);
77*ec63e07aSXin Li }
78*ec63e07aSXin Li return CompressStreamFD(api, infd, outfd, absl::GetFlag(FLAGS_level));
79*ec63e07aSXin Li }
80*ec63e07aSXin Li
main(int argc,char * argv[])81*ec63e07aSXin Li int main(int argc, char* argv[]) {
82*ec63e07aSXin Li std::string prog_name(argv[0]);
83*ec63e07aSXin Li absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
84*ec63e07aSXin Li std::vector<char*> args = absl::ParseCommandLine(argc, argv);
85*ec63e07aSXin Li absl::InitializeLog();
86*ec63e07aSXin Li
87*ec63e07aSXin Li if (args.size() != 3) {
88*ec63e07aSXin Li std::cerr << "Usage:\n " << prog_name << " INPUT OUTPUT\n";
89*ec63e07aSXin Li return EXIT_FAILURE;
90*ec63e07aSXin Li }
91*ec63e07aSXin Li
92*ec63e07aSXin Li ZstdSapiSandbox sandbox;
93*ec63e07aSXin Li if (!sandbox.Init().ok()) {
94*ec63e07aSXin Li std::cerr << "Unable to start sandbox\n";
95*ec63e07aSXin Li return EXIT_FAILURE;
96*ec63e07aSXin Li }
97*ec63e07aSXin Li
98*ec63e07aSXin Li ZstdApi api(&sandbox);
99*ec63e07aSXin Li
100*ec63e07aSXin Li absl::Status status;
101*ec63e07aSXin Li if (absl::GetFlag(FLAGS_stream)) {
102*ec63e07aSXin Li status = Stream(api, argv[1], argv[2]);
103*ec63e07aSXin Li } else {
104*ec63e07aSXin Li status = FileDescriptor(api, argv[1], argv[2]);
105*ec63e07aSXin Li }
106*ec63e07aSXin Li
107*ec63e07aSXin Li if (!status.ok()) {
108*ec63e07aSXin Li std::cerr << "Unable to ";
109*ec63e07aSXin Li std::cerr << (absl::GetFlag(FLAGS_decompress) ? "decompress" : "compress");
110*ec63e07aSXin Li std::cerr << " file.\n" << status << "\n";
111*ec63e07aSXin Li return EXIT_FAILURE;
112*ec63e07aSXin Li }
113*ec63e07aSXin Li
114*ec63e07aSXin Li return EXIT_SUCCESS;
115*ec63e07aSXin Li }
116