/*
 * Copyright 2019 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_
#define PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <future>  // NOLINT
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "private_join_and_compute/util/process_record_file_parameters.h"
#include "private_join_and_compute/util/proto_util.h"
#include "private_join_and_compute/util/recordio.h"
#include "private_join_and_compute/util/status.inc"

namespace private_join_and_compute::util::process_file_util {

// Applies the function record_transformer() to all the records in input_file,
// and writes the resulting records to output_file, sorted by the key returned
// by the provided get_sorting_key_function. By default, records are sorted by
// their serialized string representation.
// input_file must contain serialized records of type InputType;
// output_file will contain serialized records of type OutputType.
// The file is processed in chunks of at most params.data_chunk_size records:
// read a chunk, apply record_transformer() to it in parallel using
// params.thread_count threads, collect the output values returned by each
// thread, and write them to file. Repeat with the next chunk until there are
// no more records to read. params.thread_count must be at least 1.
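//
// Example usage (an illustrative sketch, not part of this library: MyInput,
// MyOutput, and the transform below are hypothetical proto types and logic,
// and the construction of params is schematic):
//
//   ProcessRecordFileParameters params;
//   params.data_chunk_size = 100000;
//   params.thread_count = 8;
//   Status status = ProcessRecordFile<MyInput, MyOutput>(
//       [](MyInput input) -> StatusOr<MyOutput> {
//         MyOutput output;
//         output.set_value(input.value() * 2);
//         return output;
//       },
//       params, "/path/to/input_file", "/path/to/output_file");
//
// A custom get_sorting_key_function can be passed as the final argument to
// sort the output by something other than the full serialized record; it
// receives each serialized output record as an absl::string_view and returns
// the string to sort by.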
template <typename InputType, typename OutputType>
Status ProcessRecordFile(
    const std::function<StatusOr<OutputType>(InputType)>& record_transformer,
    const ProcessRecordFileParameters& params, absl::string_view input_file,
    absl::string_view output_file,
    const std::function<std::string(absl::string_view)>&
        get_sorting_key_function = [](absl::string_view raw_record) {
          return std::string(raw_record);
        }) {
  auto reader = std::unique_ptr<RecordReader>(RecordReader::GetRecordReader());
  RETURN_IF_ERROR(reader->Open(input_file));

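  // The ShardingWriter buffers the transformed records and writes them out as
  // sorted intermediate shard files named with output_file as their prefix;
  // the shards are merged into the final output_file once the whole input has
  // been processed.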
  auto writer = ShardingWriter<std::string>::Get(get_sorting_key_function);
  writer->SetShardPrefix(output_file);

  std::string raw_record;
  size_t num_records_read = 0;
  // Process the file in chunks of at most params.data_chunk_size records: read
  // a chunk, transform it in parallel using params.thread_count threads,
  // collect the values returned by each thread, and write them to file.
  // Repeat until there are no more records to read.
  ASSIGN_OR_RETURN(bool has_more, reader->HasMore());
  while (has_more) {
    // Read the next chunk to process in parallel.
    num_records_read = 0;
    std::vector<InputType> chunk;
    while (num_records_read < params.data_chunk_size && has_more) {
      RETURN_IF_ERROR(reader->Read(&raw_record));
      chunk.push_back(ProtoUtils::FromString<InputType>(raw_record));
      num_records_read++;
      ASSIGN_OR_RETURN(has_more, reader->HasMore());
    }

    // The max number of records each thread will process, computed with
    // ceiling division so that every record is assigned to a thread even when
    // chunk.size() is not a multiple of params.thread_count.
    size_t per_thread_size =
        (chunk.size() + params.thread_count - 1) / params.thread_count;

    // Futures holding each thread's results. Each thread processes a
    // contiguous slice of chunk.
    std::vector<std::future<StatusOr<std::vector<OutputType>>>> futures;
    for (uint32_t j = 0; j < params.thread_count; j++) {
      size_t start = j * per_thread_size;
      // Clamp end to the number of records read, so the last thread does not
      // run past the end of the chunk.
      size_t end = std::min((j + 1) * per_thread_size, num_records_read);
      // std::launch::async forces each task onto its own thread instead of
      // allowing deferred, lazy execution.
      futures.push_back(std::async(
          std::launch::async,
          [&chunk, start, end,
           record_transformer]() -> StatusOr<std::vector<OutputType>> {
            std::vector<OutputType> processed_chunk;
            for (size_t i = start; i < end; i++) {
              ASSIGN_OR_RETURN(auto processed_record,
                               record_transformer(chunk.at(i)));
              processed_chunk.push_back(std::move(processed_record));
            }
            return processed_chunk;
          }));
    }

    // Write the processed values returned by each thread to file. future.get()
    // blocks until the corresponding thread has finished its slice.
    for (auto& future : futures) {
      ASSIGN_OR_RETURN(auto records, future.get());
      for (const auto& record : records) {
        RETURN_IF_ERROR(writer->Write(ProtoUtils::ToString(record)));
      }
    }
  }
  RETURN_IF_ERROR(reader->Close());

  // Merge all the sorted shards into a single output file and delete the
  // intermediate shard files.
  ASSIGN_OR_RETURN(auto shard_files, writer->Close());
  ShardMerger<std::string> merger;
  RETURN_IF_ERROR(
      merger.Merge(get_sorting_key_function, shard_files, output_file));
  RETURN_IF_ERROR(merger.Delete(shard_files));

  return OkStatus();
}

}  // namespace private_join_and_compute::util::process_file_util

#endif  // PRIVATE_JOIN_AND_COMPUTE_UTIL_PROCESS_RECORD_FILE_UTIL_H_