xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/util/recordio.h (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 // Defines file operations.
17 // RecordWriter generates output records that are binary data preceded with a
18 // Varint that explains the size of the records. The records provided to
19 // RecordWriter can be arbitrary binary data, but usually they will be
20 // serialized protobufs.
21 //
22 // RecordReader reads files written in the above format, and is also compatible
23 // with files written using the Java version of parseDelimitedFrom and
24 // writeDelimitedTo.
25 //
26 // LineWriter writes single lines to the output file. LineReader reads single
27 // lines from the input file.
28 //
29 // Note that all classes except ShardingWriter are not thread-safe: concurrent
30 // accesses must be protected by mutexes.
31 
32 #ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
33 #define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
34 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "private_join_and_compute/util/file.h"
#include "private_join_and_compute/util/status.inc"
44 
45 namespace private_join_and_compute {
46 
47 // Interface for reading a single file.
48 class RecordReader {
49  public:
50   virtual ~RecordReader() = default;
51 
52   // RecordReader is neither copyable nor movable.
53   RecordReader(const RecordReader&) = delete;
54   RecordReader& operator=(const RecordReader&) = delete;
55 
56   // Opens the given file for reading.
57   virtual Status Open(absl::string_view file_name) = 0;
58 
59   // Closes any file object created via calling SingleFileReader::Open
60   virtual Status Close() = 0;
61 
62   // Returns true if there are more records in the file to be read.
63   virtual StatusOr<bool> HasMore() = 0;
64 
65   // Reads a record from the file (line or binary record).
66   virtual Status Read(std::string* record) = 0;
67 
68   // Returns a RecordReader for reading files line by line.
69   // Caller takes the ownership.
70   static RecordReader* GetLineReader();
71 
72   // Returns a RecordReader for reading files in a record format compatible with
73   // RecordWriter below.
74   // Caller takes the ownership.
75   static RecordReader* GetRecordReader();
76 
77   // Test only.
78   static RecordReader* GetLineReader(File* file);
79   static RecordReader* GetRecordReader(File* file);
80 
81  protected:
82   RecordReader() = default;
83 };
84 
85 // Reads records one at a time in ascending order from multiple files, assuming
86 // each file stores records in ascending order. This class does the merge step
87 // for the external sorting. Templates T supported are string and int64.
88 template <typename T>
89 class MultiSortedReader {
90  public:
91   virtual ~MultiSortedReader() = default;
92 
93   // MultiSortedReader is neither copyable nor movable.
94   MultiSortedReader(const MultiSortedReader&) = delete;
95   MultiSortedReader& operator=(const MultiSortedReader&) = delete;
96 
97   // Opens the files generated with RecordWriterInterface. Records in each file
98   // are assumed to be sorted beforehand.
99   virtual Status Open(const std::vector<std::string>& filenames) = 0;
100 
101   // Same as Open above but also accepts a key function that is used to convert
102   // a string record into a value of type T, used when comparing the records.
103   // Records will be read from the file heads in ascending order of "key".
104   virtual Status Open(const std::vector<std::string>& filenames,
105                       const std::function<T(absl::string_view)>& key) = 0;
106 
107   // Closes the file streams.
108   virtual Status Close() = 0;
109 
110   // Returns true if there are more records in the file to be read.
111   virtual StatusOr<bool> HasMore() = 0;
112 
113   // Reads a record data into <code>data</code> in ascending order.
114   // Erases the <code>data</code> before writing to it.
115   virtual Status Read(std::string* data) = 0;
116 
117   // Same as Read(string* data) but this also puts the index of the file
118   // where the data has been read from if index is not nullptr.
119   // Erases the <code>data</code> before writing to it.
120   virtual Status Read(std::string* data, int* index) = 0;
121 
122   // Returns a MultiSortedReader.
123   // Caller takes the ownership.
124   static MultiSortedReader<T>* Get();
125 
126   // Test only.
127   static MultiSortedReader* Get(
128       const std::function<RecordReader*()>& get_reader);
129 
130  protected:
131   MultiSortedReader() = default;
132 };
133 
134 class RecordWriter {
135  public:
136   virtual ~RecordWriter() = default;
137 
138   // RecordWriter is neither copyable nor movable.
139   RecordWriter(const RecordWriter&) = delete;
140   RecordWriter& operator=(const RecordWriter&) = delete;
141 
142   // Opens the given file for writing records.
143   virtual Status Open(absl::string_view file_name) = 0;
144 
145   // Closes the file stream and returns true if successful.
146   virtual Status Close() = 0;
147 
148   // Writes <code>raw_data</code> into the file as-is, with a delimiter
149   // specifying the data size.
150   virtual Status Write(absl::string_view raw_data) = 0;
151 
152   // Returns a RecordWriter.
153   // Caller takes the ownership.
154   static RecordWriter* Get();
155 
156   // Test only.
157   static RecordWriter* Get(File* file);
158 
159  protected:
160   RecordWriter() = default;
161 };
162 
163 class LineWriter {
164  public:
165   virtual ~LineWriter() = default;
166 
167   // LineWriter is neither copyable nor movable.
168   LineWriter(const LineWriter&) = delete;
169   LineWriter& operator=(const LineWriter&) = delete;
170 
171   // Opens the given file for writing lines.
172   virtual Status Open(absl::string_view file_name) = 0;
173 
174   // Closes the file stream and returns OkStatus if successful.
175   virtual Status Close() = 0;
176 
177   // Writes <code>line</code> into the file, with a trailing newline.
178   // Returns OkStatus if the write operation was successful.
179   virtual Status Write(absl::string_view line) = 0;
180 
181   // Returns a RecordWriter.
182   // Caller takes the ownership.
183   static LineWriter* Get();
184 
185   // Test only.
186   static LineWriter* Get(File* file);
187 
188  protected:
189   LineWriter() = default;
190 };
191 
192 // Writes Records to shard files, with each shard file internally sorted based
193 // on the supplied get_key method.
194 //
195 // This class is thread-safe.
196 template <typename T>
197 class ShardingWriter {
198  public:
199   virtual ~ShardingWriter() = default;
200 
201   // ShardingWriter is neither copyable nor copy-assignable.
202   ShardingWriter(const ShardingWriter&) = delete;
203   ShardingWriter& operator=(const ShardingWriter&) = delete;
204 
205   // Shards will be created with the supplied prefix. Must be called before
206   // Write.
207   virtual void SetShardPrefix(absl::string_view shard_prefix) = 0;
208 
209   // Clears the remaining cache, and returns the list of all shard files that
210   // were written since the last call to SetShardPrefix. Caller is responsible
211   // for merging and deleting shards.
212   //
213   // Returns InternalError if clearing the remaining cache fails.
214   virtual StatusOr<std::vector<std::string>> Close() = 0;
215 
216   // Writes the supplied str into the file.
217   // Implementations need not actually write the record on each call. Rather,
218   // they may cache records until max_bytes records have been cached, at which
219   // point they may sort the cache and write it to a shard file.
220   //
221   // Implementations must return InternalError if writing the cache fails, or
222   // if the shard prefix has not been set.
223   virtual Status Write(absl::string_view raw_data) = 0;
224 
225   // Returns a ShardingWriter that uses the supplied key to compare records.
226   // @param max_bytes: denotes the maximum size of each shard to write.
227   static std::unique_ptr<ShardingWriter> Get(
228       const std::function<T(absl::string_view)>& get_key,
229       int32_t max_bytes = 209715200 /* 200MB */);
230 
231   // Test only.
232   static std::unique_ptr<ShardingWriter> Get(
233       const std::function<T(absl::string_view)>& get_key, int32_t max_bytes,
234       std::unique_ptr<RecordWriter> record_writer);
235 
236  protected:
237   ShardingWriter() = default;
238 };
239 
240 // Utility class to allow merging of sorted shards, and deleting of shards.
241 template <typename T>
242 class ShardMerger {
243  public:
244   explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader =
245                            absl::WrapUnique(MultiSortedReader<T>::Get()),
246                        std::unique_ptr<RecordWriter> writer =
247                            absl::WrapUnique(RecordWriter::Get()));
248 
249   // Merges the supplied shards into a single output file, using the supplied
250   // key.
251   Status Merge(const std::function<T(absl::string_view)>& get_key,
252                const std::vector<std::string>& shard_files,
253                absl::string_view output_file);
254 
255   // Deletes the supplied shard files.
256   Status Delete(std::vector<std::string> shard_files);
257 
258  private:
259   std::unique_ptr<MultiSortedReader<T>> multi_reader_;
260   std::unique_ptr<RecordWriter> writer_;
261 };
262 
263 }  // namespace private_join_and_compute
264 
265 #endif  // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_
266