/*
 * Copyright 2019 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_
#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <future>  // NOLINT
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "private_join_and_compute/util/process_record_file_parameters.h"
#include "private_join_and_compute/util/proto_util.h"
#include "private_join_and_compute/util/recordio.h"
#include "private_join_and_compute/util/status.inc"

namespace private_join_and_compute::util::process_file_util {

// Applies record_transformer() to every record in input_file and writes the
// resulting records to output_file, sorted by the key returned by the provided
// get_sorting_key_function. By default, records are sorted by their string
// representation.
// input_file must contain records of type InputType; output_file will contain
// records of type OutputType.
// The file is processed in parallel using the number of threads specified by
// ProcessRecordFileParameters, in chunks of at most params.data_chunk_size
// records: read a chunk, apply record_transformer() in parallel on
// params.thread_count threads, collect the output values returned by each
// thread, and write them to file. Repeat with the next chunk until there are
// no more records to read.
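//
// Example usage (an illustrative sketch only: MyInput and MyOutput stand for
// hypothetical proto message types with a `value` field, and the file paths
// are placeholders):
//
//   ProcessRecordFileParameters params;
//   params.thread_count = 4;
//   params.data_chunk_size = 1000;
//   std::function<StatusOr<MyOutput>(MyInput)> transformer =
//       [](MyInput input) -> StatusOr<MyOutput> {
//         MyOutput output;
//         output.set_value(input.value());
//         return output;
//       };
//   Status status = ProcessRecordFile<MyInput, MyOutput>(
//       transformer, params, "/path/to/input", "/path/to/output");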
template <typename InputType, typename OutputType>
Status ProcessRecordFile(
    const std::function<StatusOr<OutputType>(InputType)>& record_transformer,
    const ProcessRecordFileParameters& params, absl::string_view input_file,
    absl::string_view output_file,
    const std::function<std::string(absl::string_view)>&
        get_sorting_key_function = [](absl::string_view raw_record) {
          return std::string(raw_record);
        }) {
  auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
  RETURN_IF_ERROR(reader->Open(input_file));

  auto writer = ShardingWriter<std::string>::Get(get_sorting_key_function);
  writer->SetShardPrefix(output_file);

  std::string raw_record;
  size_t num_records_read = 0;
  // Process the file in chunks of at most data_chunk_size records: read a
  // chunk, process it in parallel using the available threads, collect the
  // values returned by each thread, and write them to file. Repeat with the
  // next chunk until there are no more records to read.
  ASSIGN_OR_RETURN(bool has_more, reader->HasMore());
  while (has_more) {
    // Read the next chunk to process in parallel.
    num_records_read = 0;
    std::vector<InputType> chunk;
    while (num_records_read < params.data_chunk_size && has_more) {
      RETURN_IF_ERROR(reader->Read(&raw_record));
      chunk.push_back(ProtoUtils::FromString<InputType>(raw_record));
      num_records_read++;
      ASSIGN_OR_RETURN(has_more, reader->HasMore());
    }

    // The maximum number of records each thread will process (the ceiling of
    // chunk.size() / thread_count).
    size_t per_thread_size =
        (chunk.size() + params.thread_count - 1) / params.thread_count;

    // Stores the results of each thread. Each thread processes a contiguous
    // slice of chunk.
    std::vector<std::future<StatusOr<std::vector<OutputType>>>> futures;
    for (uint32_t j = 0; j < params.thread_count; j++) {
      size_t start = j * per_thread_size;
      size_t end = std::min((j + 1) * per_thread_size, num_records_read);
      // std::launch::async guarantees the task runs on its own thread rather
      // than being deferred to the calling thread.
      futures.push_back(std::async(
          std::launch::async,
          [&chunk, start, end,
           record_transformer]() -> StatusOr<std::vector<OutputType>> {
            std::vector<OutputType> processed_chunk;
            for (size_t i = start; i < end; i++) {
              ASSIGN_OR_RETURN(auto processed_record,
                               record_transformer(chunk.at(i)));
              processed_chunk.push_back(std::move(processed_record));
            }
            return processed_chunk;
          }));
    }

    // Write the processed values returned by each thread to file.
    for (auto& future : futures) {
      ASSIGN_OR_RETURN(auto records, future.get());
      for (const auto& record : records) {
        RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record)));
      }
    }
  }
  RETURN_IF_ERROR(reader->Close());

  // Merge all the processed chunks into one sorted output file and delete the
  // intermediate chunk files.
  ASSIGN_OR_RETURN(auto shard_files, writer->Close());
  ShardMerger<std::string> merger;
  RETURN_IF_ERROR(
      merger.Merge(get_sorting_key_function, shard_files, output_file));
  RETURN_IF_ERROR(merger.Delete(shard_files));

  return OkStatus();
}

}  // namespace private_join_and_compute::util::process_file_util

#endif  // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_