1 /*
2 * Copyright 2019 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "private_join_and_compute/util/recordio.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <queue>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/log/log.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "absl/strings/string_view.h"
32 #include "absl/synchronization/mutex.h"
33 #include "private_join_and_compute/util/status.inc"
34 #include "src/google/protobuf/io/coded_stream.h"
35 #include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
36
37 namespace private_join_and_compute {
38
39 namespace {
40
41 // Max. size of a Varint32 (from proto references).
42 const uint32_t kMaxVarint32Size = 5;
43
44 // Tries to read a Varint32 from the front of a given file. Returns false if the
45 // reading fails.
ExtractVarint32(File * file)46 StatusOr<uint32_t> ExtractVarint32(File* file) {
47 // Keep reading a single character until one is found such that the top bit is
48 // 0;
49 std::string bytes_read = "";
50
51 size_t current_byte = 0;
52 ASSIGN_OR_RETURN(auto has_more, file->HasMore());
53 while (current_byte < kMaxVarint32Size && has_more) {
54 auto maybe_last_byte = file->Read(1);
55 if (!maybe_last_byte.ok()) {
56 return maybe_last_byte.status();
57 }
58
59 bytes_read += maybe_last_byte.value();
60 if (!(bytes_read.data()[current_byte] & 0x80)) {
61 break;
62 }
63 current_byte++;
64 // If we read the max number of bits and never found a "terminating" byte,
65 // return false.
66 if (current_byte >= kMaxVarint32Size) {
67 return InvalidArgumentError(
68 "ExtractVarint32: Failed to extract a Varint after reading max "
69 "number "
70 "of bytes.");
71 }
72 ASSIGN_OR_RETURN(has_more, file->HasMore());
73 }
74
75 google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(),
76 bytes_read.size());
77 google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream);
78 uint32_t result;
79 codedInputStream.ReadVarint32(&result);
80
81 return result;
82 }
83
84 // Reads records from a file one at a time.
85 class RecordReaderImpl : public RecordReader {
86 public:
RecordReaderImpl(File * file)87 explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {}
88
Open(absl::string_view filename)89 Status Open(absl::string_view filename) final {
90 return in_->Open(filename, "r");
91 }
92
Close()93 Status Close() final { return in_->Close(); }
94
HasMore()95 StatusOr<bool> HasMore() final {
96 auto status_or_has_more = in_->HasMore();
97 if (!status_or_has_more.ok()) {
98 LOG(ERROR) << status_or_has_more.status();
99 }
100 return status_or_has_more;
101 }
102
Read(std::string * raw_data)103 Status Read(std::string* raw_data) final {
104 raw_data->erase();
105 auto maybe_record_size = ExtractVarint32(in_.get());
106 if (!maybe_record_size.ok()) {
107 LOG(ERROR) << "RecordReader::Read: Couldn't read record size: "
108 << maybe_record_size.status();
109 return maybe_record_size.status();
110 }
111 uint32_t record_size = maybe_record_size.value();
112
113 auto status_or_data = in_->Read(record_size);
114 if (!status_or_data.ok()) {
115 LOG(ERROR) << status_or_data.status();
116 return status_or_data.status();
117 }
118
119 raw_data->append(status_or_data.value());
120 return OkStatus();
121 }
122
123 private:
124 std::unique_ptr<File> in_;
125 };
126
127 // Reads lines from a file one at a time.
128 class LineReader : public RecordReader {
129 public:
LineReader(File * file)130 explicit LineReader(File* file) : RecordReader(), in_(file) {}
131
Open(absl::string_view filename)132 Status Open(absl::string_view filename) final {
133 return in_->Open(filename, "r");
134 }
135
Close()136 Status Close() final { return in_->Close(); }
137
HasMore()138 StatusOr<bool> HasMore() final { return in_->HasMore(); }
139
Read(std::string * line)140 Status Read(std::string* line) final {
141 line->erase();
142 auto status_or_line = in_->ReadLine();
143 if (!status_or_line.ok()) {
144 LOG(ERROR) << status_or_line.status();
145 return status_or_line.status();
146 }
147 line->append(status_or_line.value());
148 return OkStatus();
149 }
150
151 private:
152 std::unique_ptr<File> in_;
153 };
154
155 template <typename T>
156 class MultiSortedReaderImpl : public MultiSortedReader<T> {
157 public:
MultiSortedReaderImpl(const std::function<RecordReader * ()> & get_reader,std::unique_ptr<std::function<T (absl::string_view)>> default_key=nullptr)158 explicit MultiSortedReaderImpl(
159 const std::function<RecordReader*()>& get_reader,
160 std::unique_ptr<std::function<T(absl::string_view)>> default_key =
161 nullptr)
162 : MultiSortedReader<T>(),
163 get_reader_(get_reader),
164 default_key_(std::move(default_key)),
165 key_(nullptr) {}
166
Open(const std::vector<std::string> & filenames)167 Status Open(const std::vector<std::string>& filenames) override {
168 if (default_key_ == nullptr) {
169 return InvalidArgumentError("The sorting key is null.");
170 }
171 return Open(filenames, *default_key_);
172 }
173
Open(const std::vector<std::string> & filenames,const std::function<T (absl::string_view)> & key)174 Status Open(const std::vector<std::string>& filenames,
175 const std::function<T(absl::string_view)>& key) override {
176 if (!readers_.empty()) {
177 return InternalError("There are files not closed, call Close() first.");
178 }
179 key_ = std::make_unique<std::function<T(absl::string_view)>>(key);
180 for (size_t i = 0; i < filenames.size(); ++i) {
181 this->readers_.push_back(std::unique_ptr<RecordReader>(get_reader_()));
182 auto open_status = this->readers_.back()->Open(filenames[i]);
183 if (!open_status.ok()) {
184 // Try to close the opened ones.
185 for (int j = i - 1; j >= 0; --j) {
186 // If closing fails as well, then any call to Open will fail as well
187 // since some of the files will remain opened.
188 auto status = this->readers_[j]->Close();
189 if (!status.ok()) {
190 LOG(ERROR) << "Error closing file " << status;
191 }
192 this->readers_.pop_back();
193 }
194 return open_status;
195 }
196 }
197 return OkStatus();
198 }
199
Close()200 Status Close() override {
201 Status status = OkStatus();
202 bool ret_val =
203 std::all_of(readers_.begin(), readers_.end(),
204 [&status](std::unique_ptr<RecordReader>& reader) {
205 Status close_status = reader->Close();
206 if (!close_status.ok()) {
207 status = close_status;
208 return false;
209 } else {
210 return true;
211 }
212 });
213 if (ret_val) {
214 readers_ = std::vector<std::unique_ptr<RecordReader>>();
215 min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>,
216 HeapDataGreater>();
217 }
218 return status;
219 }
220
HasMore()221 StatusOr<bool> HasMore() override {
222 if (!min_heap_.empty()) {
223 return true;
224 }
225 Status status = OkStatus();
226 for (const auto& reader : readers_) {
227 auto status_or_has_more = reader->HasMore();
228 if (status_or_has_more.ok()) {
229 if (status_or_has_more.value()) {
230 return true;
231 }
232 } else {
233 status = status_or_has_more.status();
234 }
235 }
236 if (status.ok()) {
237 // None of the readers has more.
238 return false;
239 }
240 return status;
241 }
242
Read(std::string * data)243 Status Read(std::string* data) override { return Read(data, nullptr); }
244
Read(std::string * data,int * index)245 Status Read(std::string* data, int* index) override {
246 if (min_heap_.empty()) {
247 for (size_t i = 0; i < readers_.size(); ++i) {
248 RETURN_IF_ERROR(this->ReadHeapDataFromReader(i));
249 }
250 }
251 HeapData ret_data = min_heap_.top();
252 data->assign(ret_data.data);
253 if (index != nullptr) *index = ret_data.index;
254 min_heap_.pop();
255 return this->ReadHeapDataFromReader(ret_data.index);
256 }
257
258 private:
ReadHeapDataFromReader(int index)259 Status ReadHeapDataFromReader(int index) {
260 std::string data;
261 auto status_or_has_more = readers_[index]->HasMore();
262 if (!status_or_has_more.ok()) {
263 return status_or_has_more.status();
264 }
265 if (status_or_has_more.value()) {
266 RETURN_IF_ERROR(readers_[index]->Read(&data));
267 HeapData heap_data;
268 heap_data.key = (*key_)(data);
269 heap_data.data = data;
270 heap_data.index = index;
271 min_heap_.push(heap_data);
272 }
273 return OkStatus();
274 }
275
276 struct HeapData {
277 T key;
278 std::string data;
279 int index;
280 };
281
282 struct HeapDataGreater {
operator ()private_join_and_compute::__anon9105c4f70111::MultiSortedReaderImpl::HeapDataGreater283 bool operator()(const HeapData& lhs, const HeapData& rhs) const {
284 return lhs.key > rhs.key;
285 }
286 };
287
288 const std::function<RecordReader*()> get_reader_;
289 std::unique_ptr<std::function<T(absl::string_view)>> default_key_;
290 std::unique_ptr<std::function<T(absl::string_view)>> key_;
291 std::vector<std::unique_ptr<RecordReader>> readers_;
292 std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater>
293 min_heap_;
294 };
295
296 // Writes records to a file one at a time.
297 class RecordWriterImpl : public RecordWriter {
298 public:
RecordWriterImpl(File * file)299 explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {}
300
Open(absl::string_view filename)301 Status Open(absl::string_view filename) final {
302 return out_->Open(filename, "w");
303 }
304
Close()305 Status Close() final { return out_->Close(); }
306
Write(absl::string_view raw_data)307 Status Write(absl::string_view raw_data) final {
308 std::string delimited_output;
309 auto string_output =
310 std::make_unique<google::protobuf::io::StringOutputStream>(
311 &delimited_output);
312 auto coded_output =
313 std::make_unique<google::protobuf::io::CodedOutputStream>(
314 string_output.get());
315
316 // Write the delimited output.
317 coded_output->WriteVarint32(raw_data.size());
318 coded_output->WriteString(std::string(raw_data));
319
320 // Force the serialization, which makes delimited_output safe to read.
321 coded_output = nullptr;
322 string_output = nullptr;
323
324 return out_->Write(delimited_output, delimited_output.size());
325 }
326
327 private:
328 std::unique_ptr<File> out_;
329 };
330
331 // Writes lines to a file one at a time.
332 class LineWriterImpl : public LineWriter {
333 public:
LineWriterImpl(File * file)334 explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {}
335
Open(absl::string_view filename)336 Status Open(absl::string_view filename) final {
337 return out_->Open(filename, "w");
338 }
339
Close()340 Status Close() final { return out_->Close(); }
341
Write(absl::string_view line)342 Status Write(absl::string_view line) final {
343 RETURN_IF_ERROR(out_->Write(line.data(), line.size()));
344 return out_->Write("\n", 1);
345 }
346
347 private:
348 std::unique_ptr<File> out_;
349 };
350
351 } // namespace
352
GetLineReader()353 RecordReader* RecordReader::GetLineReader() {
354 return RecordReader::GetLineReader(File::GetFile());
355 }
356
GetLineReader(File * file)357 RecordReader* RecordReader::GetLineReader(File* file) {
358 return new LineReader(file);
359 }
360
GetRecordReader()361 RecordReader* RecordReader::GetRecordReader() {
362 return RecordReader::GetRecordReader(File::GetFile());
363 }
364
GetRecordReader(File * file)365 RecordReader* RecordReader::GetRecordReader(File* file) {
366 return new RecordReaderImpl(file);
367 }
368
Get()369 RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); }
370
Get(File * file)371 RecordWriter* RecordWriter::Get(File* file) {
372 return new RecordWriterImpl(file);
373 }
374
Get()375 LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); }
376
Get(File * file)377 LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); }
378
379 template <typename T>
Get()380 MultiSortedReader<T>* MultiSortedReader<T>::Get() {
381 return MultiSortedReader<T>::Get(
382 []() { return RecordReader::GetRecordReader(); });
383 }
384
385 template <>
Get(const std::function<RecordReader * ()> & get_reader)386 MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get(
387 const std::function<RecordReader*()>& get_reader) {
388 return new MultiSortedReaderImpl<std::string>(
389 get_reader,
390 std::make_unique<std::function<std::string(absl::string_view)>>(
391 [](absl::string_view s) { return std::string(s); }));
392 }
393
394 template <>
Get(const std::function<RecordReader * ()> & get_reader)395 MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get(
396 const std::function<RecordReader*()>& get_reader) {
397 return new MultiSortedReaderImpl<int64_t>(
398 get_reader, std::make_unique<std::function<int64_t(absl::string_view)>>(
399 [](absl::string_view s) { return 0; }));
400 }
401
402 template class MultiSortedReader<int64_t>;
403 template class MultiSortedReader<std::string>;
404
405 namespace {
406
GetFilename(absl::string_view prefix,int32_t idx)407 std::string GetFilename(absl::string_view prefix, int32_t idx) {
408 return absl::StrCat(prefix, idx);
409 }
410
411 template <typename T>
412 class ShardingWriterImpl : public ShardingWriter<T> {
413 public:
AlreadyUnhealthyError()414 static Status AlreadyUnhealthyError() {
415 return InternalError("ShardingWriter: Already unhealthy.");
416 }
417
ShardingWriterImpl(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes=209715200,std::unique_ptr<RecordWriter> record_writer=absl::WrapUnique (RecordWriter::Get ()))418 explicit ShardingWriterImpl(
419 const std::function<T(absl::string_view)>& get_key,
420 int32_t max_bytes = 209715200, /* 200MB */
421 std::unique_ptr<RecordWriter> record_writer =
422 absl::WrapUnique(RecordWriter::Get()))
423 : get_key_(get_key),
424 record_writer_(std::move(record_writer)),
425 max_bytes_(max_bytes),
426 cache_(),
427 bytes_written_(0),
428 current_file_idx_(0),
429 shard_files_(),
430 healthy_(true),
431 open_(false) {}
432
SetShardPrefix(absl::string_view shard_prefix)433 void SetShardPrefix(absl::string_view shard_prefix) override {
434 absl::MutexLock lock(&mutex_);
435 open_ = true;
436 fnames_prefix_ = std::string(shard_prefix);
437 current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
438 }
439
Close()440 StatusOr<std::vector<std::string>> Close() override {
441 absl::MutexLock lock(&mutex_);
442
443 auto retval = TryClose();
444
445 // Guarantee that the state is reset, even if TryClose fails.
446 fnames_prefix_ = "";
447 current_fname_ = "";
448 healthy_ = true;
449 cache_.clear();
450 bytes_written_ = 0;
451 shard_files_.clear();
452 current_file_idx_ = 0;
453 open_ = false;
454
455 return retval;
456 }
457
458 // Writes the supplied Record into the file.
459 // Returns true if the write operation was successful.
Write(absl::string_view raw_record)460 Status Write(absl::string_view raw_record) override {
461 absl::MutexLock lock(&mutex_);
462 if (!open_) {
463 return InternalError("Must call SetShardPrefix before calling Write.");
464 }
465 if (!healthy_) {
466 return AlreadyUnhealthyError();
467 }
468 if (bytes_written_ > max_bytes_) {
469 RETURN_IF_ERROR(WriteCacheToFile());
470 }
471 bytes_written_ += raw_record.size();
472 cache_.push_back(std::string(raw_record));
473 return OkStatus();
474 }
475
476 private:
WriteCacheToFile()477 Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
478 if (!healthy_) return AlreadyUnhealthyError();
479 if (cache_.empty()) return OkStatus();
480 cache_.sort([this](absl::string_view r1, absl::string_view r2) {
481 return get_key_(r1) < get_key_(r2);
482 });
483 if (!record_writer_->Open(current_fname_).ok()) {
484 healthy_ = false;
485 return InternalError(
486 absl::StrCat("Cannot open ", current_fname_, " for writing."));
487 }
488 Status status = absl::OkStatus();
489 for (absl::string_view r : cache_) {
490 if (!record_writer_->Write(r).ok()) {
491 healthy_ = false;
492 status = InternalError(
493 absl::StrCat("Cannot write record ", r, " to ", current_fname_));
494
495 break;
496 }
497 }
498 if (!record_writer_->Close().ok()) {
499 if (status.ok()) {
500 status =
501 InternalError(absl::StrCat("Cannot close ", current_fname_, "."));
502 } else {
503 // Preserve the old status message.
504 LOG(WARNING) << "Cannot close " << current_fname_;
505 }
506 }
507
508 shard_files_.push_back(current_fname_);
509 cache_.clear();
510 bytes_written_ = 0;
511 ++current_file_idx_;
512 current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
513 return status;
514 }
515
TryClose()516 StatusOr<std::vector<std::string>> TryClose()
517 ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
518 if (!open_) {
519 return InternalError("Must call SetShardPrefix before calling Close.");
520 }
521 RETURN_IF_ERROR(WriteCacheToFile());
522
523 return {shard_files_};
524 }
525
526 absl::Mutex mutex_;
527 std::function<T(absl::string_view)> get_key_;
528 std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_);
529 std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_);
530 const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_);
531 std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_);
532 int32_t bytes_written_ ABSL_GUARDED_BY(mutex_);
533 int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_);
534 std::string current_fname_ ABSL_GUARDED_BY(mutex_);
535 std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_);
536 bool healthy_ ABSL_GUARDED_BY(mutex_);
537 bool open_ ABSL_GUARDED_BY(mutex_);
538 };
539
540 } // namespace
541
542 template <typename T>
Get(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes)543 std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
544 const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) {
545 return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes);
546 }
547
548 // Test only.
549 template <typename T>
Get(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes,std::unique_ptr<RecordWriter> record_writer)550 std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
551 const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
552 std::unique_ptr<RecordWriter> record_writer) {
553 return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes,
554 std::move(record_writer));
555 }
556
557 template class ShardingWriter<int64_t>;
558 template class ShardingWriter<std::string>;
559
560 template <typename T>
ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,std::unique_ptr<RecordWriter> writer)561 ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,
562 std::unique_ptr<RecordWriter> writer)
563 : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {}
564
565 template <typename T>
Merge(const std::function<T (absl::string_view)> & get_key,const std::vector<std::string> & shard_files,absl::string_view output_file)566 Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key,
567 const std::vector<std::string>& shard_files,
568 absl::string_view output_file) {
569 if (shard_files.empty()) {
570 // Create an empty output file.
571 RETURN_IF_ERROR(writer_->Open(output_file));
572 RETURN_IF_ERROR(writer_->Close());
573 }
574
575 // Multi-sorted-read all shards, and write the results to the supplied file.
576 std::vector<std::string> converted_shard_files;
577 converted_shard_files.reserve(shard_files.size());
578 for (const auto& filename : shard_files) {
579 converted_shard_files.push_back(filename);
580 }
581
582 RETURN_IF_ERROR(multi_reader_->Open(converted_shard_files, get_key));
583
584 RETURN_IF_ERROR(writer_->Open(output_file));
585
586 for (std::string record; multi_reader_->HasMore().value();) {
587 RETURN_IF_ERROR(multi_reader_->Read(&record));
588 RETURN_IF_ERROR(writer_->Write(record));
589 }
590 RETURN_IF_ERROR(writer_->Close());
591
592 RETURN_IF_ERROR(multi_reader_->Close());
593
594 return OkStatus();
595 }
596
597 template <typename T>
Delete(std::vector<std::string> shard_files)598 Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) {
599 for (const auto& filename : shard_files) {
600 RETURN_IF_ERROR(DeleteFile(filename));
601 }
602
603 return OkStatus();
604 }
605
606 template class ShardMerger<int64_t>;
607 template class ShardMerger<std::string>;
608
609 } // namespace private_join_and_compute
610