xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/util/recordio.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/util/recordio.h"
17 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "private_join_and_compute/util/status.inc"
#include "src/google/protobuf/io/coded_stream.h"
#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
36 
37 namespace private_join_and_compute {
38 
39 namespace {
40 
41 // Max. size of a Varint32 (from proto references).
42 const uint32_t kMaxVarint32Size = 5;
43 
44 // Tries to read a Varint32 from the front of a given file. Returns false if the
45 // reading fails.
ExtractVarint32(File * file)46 StatusOr<uint32_t> ExtractVarint32(File* file) {
47   // Keep reading a single character until one is found such that the top bit is
48   // 0;
49   std::string bytes_read = "";
50 
51   size_t current_byte = 0;
52   ASSIGN_OR_RETURN(auto has_more, file->HasMore());
53   while (current_byte < kMaxVarint32Size && has_more) {
54     auto maybe_last_byte = file->Read(1);
55     if (!maybe_last_byte.ok()) {
56       return maybe_last_byte.status();
57     }
58 
59     bytes_read += maybe_last_byte.value();
60     if (!(bytes_read.data()[current_byte] & 0x80)) {
61       break;
62     }
63     current_byte++;
64     // If we read the max number of bits and never found a "terminating" byte,
65     // return false.
66     if (current_byte >= kMaxVarint32Size) {
67       return InvalidArgumentError(
68           "ExtractVarint32: Failed to extract a Varint after reading max "
69           "number "
70           "of bytes.");
71     }
72     ASSIGN_OR_RETURN(has_more, file->HasMore());
73   }
74 
75   google::protobuf::io::ArrayInputStream arrayInputStream(bytes_read.data(),
76                                                           bytes_read.size());
77   google::protobuf::io::CodedInputStream codedInputStream(&arrayInputStream);
78   uint32_t result;
79   codedInputStream.ReadVarint32(&result);
80 
81   return result;
82 }
83 
84 // Reads records from a file one at a time.
85 class RecordReaderImpl : public RecordReader {
86  public:
RecordReaderImpl(File * file)87   explicit RecordReaderImpl(File* file) : RecordReader(), in_(file) {}
88 
Open(absl::string_view filename)89   Status Open(absl::string_view filename) final {
90     return in_->Open(filename, "r");
91   }
92 
Close()93   Status Close() final { return in_->Close(); }
94 
HasMore()95   StatusOr<bool> HasMore() final {
96     auto status_or_has_more = in_->HasMore();
97     if (!status_or_has_more.ok()) {
98       LOG(ERROR) << status_or_has_more.status();
99     }
100     return status_or_has_more;
101   }
102 
Read(std::string * raw_data)103   Status Read(std::string* raw_data) final {
104     raw_data->erase();
105     auto maybe_record_size = ExtractVarint32(in_.get());
106     if (!maybe_record_size.ok()) {
107       LOG(ERROR) << "RecordReader::Read: Couldn't read record size: "
108                  << maybe_record_size.status();
109       return maybe_record_size.status();
110     }
111     uint32_t record_size = maybe_record_size.value();
112 
113     auto status_or_data = in_->Read(record_size);
114     if (!status_or_data.ok()) {
115       LOG(ERROR) << status_or_data.status();
116       return status_or_data.status();
117     }
118 
119     raw_data->append(status_or_data.value());
120     return OkStatus();
121   }
122 
123  private:
124   std::unique_ptr<File> in_;
125 };
126 
127 // Reads lines from a file one at a time.
128 class LineReader : public RecordReader {
129  public:
LineReader(File * file)130   explicit LineReader(File* file) : RecordReader(), in_(file) {}
131 
Open(absl::string_view filename)132   Status Open(absl::string_view filename) final {
133     return in_->Open(filename, "r");
134   }
135 
Close()136   Status Close() final { return in_->Close(); }
137 
HasMore()138   StatusOr<bool> HasMore() final { return in_->HasMore(); }
139 
Read(std::string * line)140   Status Read(std::string* line) final {
141     line->erase();
142     auto status_or_line = in_->ReadLine();
143     if (!status_or_line.ok()) {
144       LOG(ERROR) << status_or_line.status();
145       return status_or_line.status();
146     }
147     line->append(status_or_line.value());
148     return OkStatus();
149   }
150 
151  private:
152   std::unique_ptr<File> in_;
153 };
154 
155 template <typename T>
156 class MultiSortedReaderImpl : public MultiSortedReader<T> {
157  public:
MultiSortedReaderImpl(const std::function<RecordReader * ()> & get_reader,std::unique_ptr<std::function<T (absl::string_view)>> default_key=nullptr)158   explicit MultiSortedReaderImpl(
159       const std::function<RecordReader*()>& get_reader,
160       std::unique_ptr<std::function<T(absl::string_view)>> default_key =
161           nullptr)
162       : MultiSortedReader<T>(),
163         get_reader_(get_reader),
164         default_key_(std::move(default_key)),
165         key_(nullptr) {}
166 
Open(const std::vector<std::string> & filenames)167   Status Open(const std::vector<std::string>& filenames) override {
168     if (default_key_ == nullptr) {
169       return InvalidArgumentError("The sorting key is null.");
170     }
171     return Open(filenames, *default_key_);
172   }
173 
Open(const std::vector<std::string> & filenames,const std::function<T (absl::string_view)> & key)174   Status Open(const std::vector<std::string>& filenames,
175               const std::function<T(absl::string_view)>& key) override {
176     if (!readers_.empty()) {
177       return InternalError("There are files not closed, call Close() first.");
178     }
179     key_ = std::make_unique<std::function<T(absl::string_view)>>(key);
180     for (size_t i = 0; i < filenames.size(); ++i) {
181       this->readers_.push_back(std::unique_ptr<RecordReader>(get_reader_()));
182       auto open_status = this->readers_.back()->Open(filenames[i]);
183       if (!open_status.ok()) {
184         // Try to close the opened ones.
185         for (int j = i - 1; j >= 0; --j) {
186           // If closing fails as well, then any call to Open will fail as well
187           // since some of the files will remain opened.
188           auto status = this->readers_[j]->Close();
189           if (!status.ok()) {
190             LOG(ERROR) << "Error closing file " << status;
191           }
192           this->readers_.pop_back();
193         }
194         return open_status;
195       }
196     }
197     return OkStatus();
198   }
199 
Close()200   Status Close() override {
201     Status status = OkStatus();
202     bool ret_val =
203         std::all_of(readers_.begin(), readers_.end(),
204                     [&status](std::unique_ptr<RecordReader>& reader) {
205                       Status close_status = reader->Close();
206                       if (!close_status.ok()) {
207                         status = close_status;
208                         return false;
209                       } else {
210                         return true;
211                       }
212                     });
213     if (ret_val) {
214       readers_ = std::vector<std::unique_ptr<RecordReader>>();
215       min_heap_ = std::priority_queue<HeapData, std::vector<HeapData>,
216                                       HeapDataGreater>();
217     }
218     return status;
219   }
220 
HasMore()221   StatusOr<bool> HasMore() override {
222     if (!min_heap_.empty()) {
223       return true;
224     }
225     Status status = OkStatus();
226     for (const auto& reader : readers_) {
227       auto status_or_has_more = reader->HasMore();
228       if (status_or_has_more.ok()) {
229         if (status_or_has_more.value()) {
230           return true;
231         }
232       } else {
233         status = status_or_has_more.status();
234       }
235     }
236     if (status.ok()) {
237       // None of the readers has more.
238       return false;
239     }
240     return status;
241   }
242 
Read(std::string * data)243   Status Read(std::string* data) override { return Read(data, nullptr); }
244 
Read(std::string * data,int * index)245   Status Read(std::string* data, int* index) override {
246     if (min_heap_.empty()) {
247       for (size_t i = 0; i < readers_.size(); ++i) {
248         RETURN_IF_ERROR(this->ReadHeapDataFromReader(i));
249       }
250     }
251     HeapData ret_data = min_heap_.top();
252     data->assign(ret_data.data);
253     if (index != nullptr) *index = ret_data.index;
254     min_heap_.pop();
255     return this->ReadHeapDataFromReader(ret_data.index);
256   }
257 
258  private:
ReadHeapDataFromReader(int index)259   Status ReadHeapDataFromReader(int index) {
260     std::string data;
261     auto status_or_has_more = readers_[index]->HasMore();
262     if (!status_or_has_more.ok()) {
263       return status_or_has_more.status();
264     }
265     if (status_or_has_more.value()) {
266       RETURN_IF_ERROR(readers_[index]->Read(&data));
267       HeapData heap_data;
268       heap_data.key = (*key_)(data);
269       heap_data.data = data;
270       heap_data.index = index;
271       min_heap_.push(heap_data);
272     }
273     return OkStatus();
274   }
275 
276   struct HeapData {
277     T key;
278     std::string data;
279     int index;
280   };
281 
282   struct HeapDataGreater {
operator ()private_join_and_compute::__anon9105c4f70111::MultiSortedReaderImpl::HeapDataGreater283     bool operator()(const HeapData& lhs, const HeapData& rhs) const {
284       return lhs.key > rhs.key;
285     }
286   };
287 
288   const std::function<RecordReader*()> get_reader_;
289   std::unique_ptr<std::function<T(absl::string_view)>> default_key_;
290   std::unique_ptr<std::function<T(absl::string_view)>> key_;
291   std::vector<std::unique_ptr<RecordReader>> readers_;
292   std::priority_queue<HeapData, std::vector<HeapData>, HeapDataGreater>
293       min_heap_;
294 };
295 
296 // Writes records to a file one at a time.
297 class RecordWriterImpl : public RecordWriter {
298  public:
RecordWriterImpl(File * file)299   explicit RecordWriterImpl(File* file) : RecordWriter(), out_(file) {}
300 
Open(absl::string_view filename)301   Status Open(absl::string_view filename) final {
302     return out_->Open(filename, "w");
303   }
304 
Close()305   Status Close() final { return out_->Close(); }
306 
Write(absl::string_view raw_data)307   Status Write(absl::string_view raw_data) final {
308     std::string delimited_output;
309     auto string_output =
310         std::make_unique<google::protobuf::io::StringOutputStream>(
311             &delimited_output);
312     auto coded_output =
313         std::make_unique<google::protobuf::io::CodedOutputStream>(
314             string_output.get());
315 
316     // Write the delimited output.
317     coded_output->WriteVarint32(raw_data.size());
318     coded_output->WriteString(std::string(raw_data));
319 
320     // Force the serialization, which makes delimited_output safe to read.
321     coded_output = nullptr;
322     string_output = nullptr;
323 
324     return out_->Write(delimited_output, delimited_output.size());
325   }
326 
327  private:
328   std::unique_ptr<File> out_;
329 };
330 
331 // Writes lines to a file one at a time.
332 class LineWriterImpl : public LineWriter {
333  public:
LineWriterImpl(File * file)334   explicit LineWriterImpl(File* file) : LineWriter(), out_(file) {}
335 
Open(absl::string_view filename)336   Status Open(absl::string_view filename) final {
337     return out_->Open(filename, "w");
338   }
339 
Close()340   Status Close() final { return out_->Close(); }
341 
Write(absl::string_view line)342   Status Write(absl::string_view line) final {
343     RETURN_IF_ERROR(out_->Write(line.data(), line.size()));
344     return out_->Write("\n", 1);
345   }
346 
347  private:
348   std::unique_ptr<File> out_;
349 };
350 
351 }  // namespace
352 
GetLineReader()353 RecordReader* RecordReader::GetLineReader() {
354   return RecordReader::GetLineReader(File::GetFile());
355 }
356 
GetLineReader(File * file)357 RecordReader* RecordReader::GetLineReader(File* file) {
358   return new LineReader(file);
359 }
360 
GetRecordReader()361 RecordReader* RecordReader::GetRecordReader() {
362   return RecordReader::GetRecordReader(File::GetFile());
363 }
364 
GetRecordReader(File * file)365 RecordReader* RecordReader::GetRecordReader(File* file) {
366   return new RecordReaderImpl(file);
367 }
368 
Get()369 RecordWriter* RecordWriter::Get() { return RecordWriter::Get(File::GetFile()); }
370 
Get(File * file)371 RecordWriter* RecordWriter::Get(File* file) {
372   return new RecordWriterImpl(file);
373 }
374 
Get()375 LineWriter* LineWriter::Get() { return LineWriter::Get(File::GetFile()); }
376 
Get(File * file)377 LineWriter* LineWriter::Get(File* file) { return new LineWriterImpl(file); }
378 
379 template <typename T>
Get()380 MultiSortedReader<T>* MultiSortedReader<T>::Get() {
381   return MultiSortedReader<T>::Get(
382       []() { return RecordReader::GetRecordReader(); });
383 }
384 
385 template <>
Get(const std::function<RecordReader * ()> & get_reader)386 MultiSortedReader<std::string>* MultiSortedReader<std::string>::Get(
387     const std::function<RecordReader*()>& get_reader) {
388   return new MultiSortedReaderImpl<std::string>(
389       get_reader,
390       std::make_unique<std::function<std::string(absl::string_view)>>(
391           [](absl::string_view s) { return std::string(s); }));
392 }
393 
394 template <>
Get(const std::function<RecordReader * ()> & get_reader)395 MultiSortedReader<int64_t>* MultiSortedReader<int64_t>::Get(
396     const std::function<RecordReader*()>& get_reader) {
397   return new MultiSortedReaderImpl<int64_t>(
398       get_reader, std::make_unique<std::function<int64_t(absl::string_view)>>(
399                       [](absl::string_view s) { return 0; }));
400 }
401 
402 template class MultiSortedReader<int64_t>;
403 template class MultiSortedReader<std::string>;
404 
405 namespace {
406 
GetFilename(absl::string_view prefix,int32_t idx)407 std::string GetFilename(absl::string_view prefix, int32_t idx) {
408   return absl::StrCat(prefix, idx);
409 }
410 
411 template <typename T>
412 class ShardingWriterImpl : public ShardingWriter<T> {
413  public:
AlreadyUnhealthyError()414   static Status AlreadyUnhealthyError() {
415     return InternalError("ShardingWriter: Already unhealthy.");
416   }
417 
ShardingWriterImpl(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes=209715200,std::unique_ptr<RecordWriter> record_writer=absl::WrapUnique (RecordWriter::Get ()))418   explicit ShardingWriterImpl(
419       const std::function<T(absl::string_view)>& get_key,
420       int32_t max_bytes = 209715200, /* 200MB */
421       std::unique_ptr<RecordWriter> record_writer =
422           absl::WrapUnique(RecordWriter::Get()))
423       : get_key_(get_key),
424         record_writer_(std::move(record_writer)),
425         max_bytes_(max_bytes),
426         cache_(),
427         bytes_written_(0),
428         current_file_idx_(0),
429         shard_files_(),
430         healthy_(true),
431         open_(false) {}
432 
SetShardPrefix(absl::string_view shard_prefix)433   void SetShardPrefix(absl::string_view shard_prefix) override {
434     absl::MutexLock lock(&mutex_);
435     open_ = true;
436     fnames_prefix_ = std::string(shard_prefix);
437     current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
438   }
439 
Close()440   StatusOr<std::vector<std::string>> Close() override {
441     absl::MutexLock lock(&mutex_);
442 
443     auto retval = TryClose();
444 
445     // Guarantee that the state is reset, even if TryClose fails.
446     fnames_prefix_ = "";
447     current_fname_ = "";
448     healthy_ = true;
449     cache_.clear();
450     bytes_written_ = 0;
451     shard_files_.clear();
452     current_file_idx_ = 0;
453     open_ = false;
454 
455     return retval;
456   }
457 
458   // Writes the supplied Record into the file.
459   // Returns true if the write operation was successful.
Write(absl::string_view raw_record)460   Status Write(absl::string_view raw_record) override {
461     absl::MutexLock lock(&mutex_);
462     if (!open_) {
463       return InternalError("Must call SetShardPrefix before calling Write.");
464     }
465     if (!healthy_) {
466       return AlreadyUnhealthyError();
467     }
468     if (bytes_written_ > max_bytes_) {
469       RETURN_IF_ERROR(WriteCacheToFile());
470     }
471     bytes_written_ += raw_record.size();
472     cache_.push_back(std::string(raw_record));
473     return OkStatus();
474   }
475 
476  private:
WriteCacheToFile()477   Status WriteCacheToFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
478     if (!healthy_) return AlreadyUnhealthyError();
479     if (cache_.empty()) return OkStatus();
480     cache_.sort([this](absl::string_view r1, absl::string_view r2) {
481       return get_key_(r1) < get_key_(r2);
482     });
483     if (!record_writer_->Open(current_fname_).ok()) {
484       healthy_ = false;
485       return InternalError(
486           absl::StrCat("Cannot open ", current_fname_, " for writing."));
487     }
488     Status status = absl::OkStatus();
489     for (absl::string_view r : cache_) {
490       if (!record_writer_->Write(r).ok()) {
491         healthy_ = false;
492         status = InternalError(
493             absl::StrCat("Cannot write record ", r, " to ", current_fname_));
494 
495         break;
496       }
497     }
498     if (!record_writer_->Close().ok()) {
499       if (status.ok()) {
500         status =
501             InternalError(absl::StrCat("Cannot close ", current_fname_, "."));
502       } else {
503         // Preserve the old status message.
504         LOG(WARNING) << "Cannot close " << current_fname_;
505       }
506     }
507 
508     shard_files_.push_back(current_fname_);
509     cache_.clear();
510     bytes_written_ = 0;
511     ++current_file_idx_;
512     current_fname_ = GetFilename(fnames_prefix_, current_file_idx_);
513     return status;
514   }
515 
TryClose()516   StatusOr<std::vector<std::string>> TryClose()
517       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
518     if (!open_) {
519       return InternalError("Must call SetShardPrefix before calling Close.");
520     }
521     RETURN_IF_ERROR(WriteCacheToFile());
522 
523     return {shard_files_};
524   }
525 
526   absl::Mutex mutex_;
527   std::function<T(absl::string_view)> get_key_;
528   std::unique_ptr<RecordWriter> record_writer_ ABSL_GUARDED_BY(mutex_);
529   std::string fnames_prefix_ ABSL_GUARDED_BY(mutex_);
530   const int32_t max_bytes_ ABSL_GUARDED_BY(mutex_);
531   std::list<std::string> cache_ ABSL_GUARDED_BY(mutex_);
532   int32_t bytes_written_ ABSL_GUARDED_BY(mutex_);
533   int32_t current_file_idx_ ABSL_GUARDED_BY(mutex_);
534   std::string current_fname_ ABSL_GUARDED_BY(mutex_);
535   std::vector<std::string> shard_files_ ABSL_GUARDED_BY(mutex_);
536   bool healthy_ ABSL_GUARDED_BY(mutex_);
537   bool open_ ABSL_GUARDED_BY(mutex_);
538 };
539 
540 }  // namespace
541 
542 template <typename T>
Get(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes)543 std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
544     const std::function<T(absl::string_view)>& get_key, int32_t max_bytes) {
545   return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes);
546 }
547 
548 // Test only.
549 template <typename T>
Get(const std::function<T (absl::string_view)> & get_key,int32_t max_bytes,std::unique_ptr<RecordWriter> record_writer)550 std::unique_ptr<ShardingWriter<T>> ShardingWriter<T>::Get(
551     const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
552     std::unique_ptr<RecordWriter> record_writer) {
553   return std::make_unique<ShardingWriterImpl<T>>(get_key, max_bytes,
554                                                  std::move(record_writer));
555 }
556 
557 template class ShardingWriter<int64_t>;
558 template class ShardingWriter<std::string>;
559 
560 template <typename T>
ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,std::unique_ptr<RecordWriter> writer)561 ShardMerger<T>::ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader,
562                             std::unique_ptr<RecordWriter> writer)
563     : multi_reader_(std::move(multi_reader)), writer_(std::move(writer)) {}
564 
565 template <typename T>
Merge(const std::function<T (absl::string_view)> & get_key,const std::vector<std::string> & shard_files,absl::string_view output_file)566 Status ShardMerger<T>::Merge(const std::function<T(absl::string_view)>& get_key,
567                              const std::vector<std::string>& shard_files,
568                              absl::string_view output_file) {
569   if (shard_files.empty()) {
570     // Create an empty output file.
571     RETURN_IF_ERROR(writer_->Open(output_file));
572     RETURN_IF_ERROR(writer_->Close());
573   }
574 
575   // Multi-sorted-read all shards, and write the results to the supplied file.
576   std::vector<std::string> converted_shard_files;
577   converted_shard_files.reserve(shard_files.size());
578   for (const auto& filename : shard_files) {
579     converted_shard_files.push_back(filename);
580   }
581 
582   RETURN_IF_ERROR(multi_reader_->Open(converted_shard_files, get_key));
583 
584   RETURN_IF_ERROR(writer_->Open(output_file));
585 
586   for (std::string record; multi_reader_->HasMore().value();) {
587     RETURN_IF_ERROR(multi_reader_->Read(&record));
588     RETURN_IF_ERROR(writer_->Write(record));
589   }
590   RETURN_IF_ERROR(writer_->Close());
591 
592   RETURN_IF_ERROR(multi_reader_->Close());
593 
594   return OkStatus();
595 }
596 
597 template <typename T>
Delete(std::vector<std::string> shard_files)598 Status ShardMerger<T>::Delete(std::vector<std::string> shard_files) {
599   for (const auto& filename : shard_files) {
600     RETURN_IF_ERROR(DeleteFile(filename));
601   }
602 
603   return OkStatus();
604 }
605 
606 template class ShardMerger<int64_t>;
607 template class ShardMerger<std::string>;
608 
609 }  // namespace private_join_and_compute
610