1 /* 2 * Copyright 2019 Google LLC. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * https://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 // Defines file operations. 17 // RecordWriter generates output records that are binary data preceded with a 18 // Varint that explains the size of the records. The records provided to 19 // RecordWriter can be arbitrary binary data, but usually they will be 20 // serialized protobufs. 21 // 22 // RecordReader reads files written in the above format, and is also compatible 23 // with files written using the Java version of parseDelimitedFrom and 24 // writeDelimitedTo. 25 // 26 // LineWriter writes single lines to the output file. LineReader reads single 27 // lines from the input file. 28 // 29 // Note that all classes except ShardingWriter are not thread-safe: concurrent 30 // accesses must be protected by mutexes. 31 32 #ifndef PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ 33 #define PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ 34 35 #include <functional> 36 #include <memory> 37 #include <string> 38 #include <vector> 39 40 #include "absl/memory/memory.h" 41 #include "absl/strings/string_view.h" 42 #include "private_join_and_compute/util/file.h" 43 #include "private_join_and_compute/util/status.inc" 44 45 namespace private_join_and_compute { 46 47 // Interface for reading a single file. 48 class RecordReader { 49 public: 50 virtual ~RecordReader() = default; 51 52 // RecordReader is neither copyable nor movable. 53 RecordReader(const RecordReader&) = delete; 54 RecordReader& operator=(const RecordReader&) = delete; 55 56 // Opens the given file for reading. 57 virtual Status Open(absl::string_view file_name) = 0; 58 59 // Closes any file object created via calling SingleFileReader::Open 60 virtual Status Close() = 0; 61 62 // Returns true if there are more records in the file to be read. 63 virtual StatusOr<bool> HasMore() = 0; 64 65 // Reads a record from the file (line or binary record). 66 virtual Status Read(std::string* record) = 0; 67 68 // Returns a RecordReader for reading files line by line. 69 // Caller takes the ownership. 70 static RecordReader* GetLineReader(); 71 72 // Returns a RecordReader for reading files in a record format compatible with 73 // RecordWriter below. 74 // Caller takes the ownership. 75 static RecordReader* GetRecordReader(); 76 77 // Test only. 78 static RecordReader* GetLineReader(File* file); 79 static RecordReader* GetRecordReader(File* file); 80 81 protected: 82 RecordReader() = default; 83 }; 84 85 // Reads records one at a time in ascending order from multiple files, assuming 86 // each file stores records in ascending order. This class does the merge step 87 // for the external sorting. Templates T supported are string and int64. 88 template <typename T> 89 class MultiSortedReader { 90 public: 91 virtual ~MultiSortedReader() = default; 92 93 // MultiSortedReader is neither copyable nor movable. 94 MultiSortedReader(const MultiSortedReader&) = delete; 95 MultiSortedReader& operator=(const MultiSortedReader&) = delete; 96 97 // Opens the files generated with RecordWriterInterface. Records in each file 98 // are assumed to be sorted beforehand. 99 virtual Status Open(const std::vector<std::string>& filenames) = 0; 100 101 // Same as Open above but also accepts a key function that is used to convert 102 // a string record into a value of type T, used when comparing the records. 103 // Records will be read from the file heads in ascending order of "key". 104 virtual Status Open(const std::vector<std::string>& filenames, 105 const std::function<T(absl::string_view)>& key) = 0; 106 107 // Closes the file streams. 108 virtual Status Close() = 0; 109 110 // Returns true if there are more records in the file to be read. 111 virtual StatusOr<bool> HasMore() = 0; 112 113 // Reads a record data into <code>data</code> in ascending order. 114 // Erases the <code>data</code> before writing to it. 115 virtual Status Read(std::string* data) = 0; 116 117 // Same as Read(string* data) but this also puts the index of the file 118 // where the data has been read from if index is not nullptr. 119 // Erases the <code>data</code> before writing to it. 120 virtual Status Read(std::string* data, int* index) = 0; 121 122 // Returns a MultiSortedReader. 123 // Caller takes the ownership. 124 static MultiSortedReader<T>* Get(); 125 126 // Test only. 127 static MultiSortedReader* Get( 128 const std::function<RecordReader*()>& get_reader); 129 130 protected: 131 MultiSortedReader() = default; 132 }; 133 134 class RecordWriter { 135 public: 136 virtual ~RecordWriter() = default; 137 138 // RecordWriter is neither copyable nor movable. 139 RecordWriter(const RecordWriter&) = delete; 140 RecordWriter& operator=(const RecordWriter&) = delete; 141 142 // Opens the given file for writing records. 143 virtual Status Open(absl::string_view file_name) = 0; 144 145 // Closes the file stream and returns true if successful. 146 virtual Status Close() = 0; 147 148 // Writes <code>raw_data</code> into the file as-is, with a delimiter 149 // specifying the data size. 150 virtual Status Write(absl::string_view raw_data) = 0; 151 152 // Returns a RecordWriter. 153 // Caller takes the ownership. 154 static RecordWriter* Get(); 155 156 // Test only. 157 static RecordWriter* Get(File* file); 158 159 protected: 160 RecordWriter() = default; 161 }; 162 163 class LineWriter { 164 public: 165 virtual ~LineWriter() = default; 166 167 // LineWriter is neither copyable nor movable. 168 LineWriter(const LineWriter&) = delete; 169 LineWriter& operator=(const LineWriter&) = delete; 170 171 // Opens the given file for writing lines. 172 virtual Status Open(absl::string_view file_name) = 0; 173 174 // Closes the file stream and returns OkStatus if successful. 175 virtual Status Close() = 0; 176 177 // Writes <code>line</code> into the file, with a trailing newline. 178 // Returns OkStatus if the write operation was successful. 179 virtual Status Write(absl::string_view line) = 0; 180 181 // Returns a RecordWriter. 182 // Caller takes the ownership. 183 static LineWriter* Get(); 184 185 // Test only. 186 static LineWriter* Get(File* file); 187 188 protected: 189 LineWriter() = default; 190 }; 191 192 // Writes Records to shard files, with each shard file internally sorted based 193 // on the supplied get_key method. 194 // 195 // This class is thread-safe. 196 template <typename T> 197 class ShardingWriter { 198 public: 199 virtual ~ShardingWriter() = default; 200 201 // ShardingWriter is neither copyable nor copy-assignable. 202 ShardingWriter(const ShardingWriter&) = delete; 203 ShardingWriter& operator=(const ShardingWriter&) = delete; 204 205 // Shards will be created with the supplied prefix. Must be called before 206 // Write. 207 virtual void SetShardPrefix(absl::string_view shard_prefix) = 0; 208 209 // Clears the remaining cache, and returns the list of all shard files that 210 // were written since the last call to SetShardPrefix. Caller is responsible 211 // for merging and deleting shards. 212 // 213 // Returns InternalError if clearing the remaining cache fails. 214 virtual StatusOr<std::vector<std::string>> Close() = 0; 215 216 // Writes the supplied str into the file. 217 // Implementations need not actually write the record on each call. Rather, 218 // they may cache records until max_bytes records have been cached, at which 219 // point they may sort the cache and write it to a shard file. 220 // 221 // Implementations must return InternalError if writing the cache fails, or 222 // if the shard prefix has not been set. 223 virtual Status Write(absl::string_view raw_data) = 0; 224 225 // Returns a ShardingWriter that uses the supplied key to compare records. 226 // @param max_bytes: denotes the maximum size of each shard to write. 227 static std::unique_ptr<ShardingWriter> Get( 228 const std::function<T(absl::string_view)>& get_key, 229 int32_t max_bytes = 209715200 /* 200MB */); 230 231 // Test only. 232 static std::unique_ptr<ShardingWriter> Get( 233 const std::function<T(absl::string_view)>& get_key, int32_t max_bytes, 234 std::unique_ptr<RecordWriter> record_writer); 235 236 protected: 237 ShardingWriter() = default; 238 }; 239 240 // Utility class to allow merging of sorted shards, and deleting of shards. 241 template <typename T> 242 class ShardMerger { 243 public: 244 explicit ShardMerger(std::unique_ptr<MultiSortedReader<T>> multi_reader = 245 absl::WrapUnique(MultiSortedReader<T>::Get()), 246 std::unique_ptr<RecordWriter> writer = 247 absl::WrapUnique(RecordWriter::Get())); 248 249 // Merges the supplied shards into a single output file, using the supplied 250 // key. 251 Status Merge(const std::function<T(absl::string_view)>& get_key, 252 const std::vector<std::string>& shard_files, 253 absl::string_view output_file); 254 255 // Deletes the supplied shard files. 256 Status Delete(std::vector<std::string> shard_files); 257 258 private: 259 std::unique_ptr<MultiSortedReader<T>> multi_reader_; 260 std::unique_ptr<RecordWriter> writer_; 261 }; 262 263 } // namespace private_join_and_compute 264 265 #endif // PRIVATE_JOIN_AND_COMPUTE_INTERNAL_UTIL_RECORDIO_H_ 266