xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/data_util.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/data_util.h"
17 
18 #include <algorithm>
19 #include <cctype>
20 #include <fstream>
21 #include <limits>
22 #include <random>
23 #include <string>
24 #include <tuple>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/container/btree_set.h"
29 #include "absl/strings/numbers.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_replace.h"
32 #include "absl/strings/string_view.h"
33 #include "private_join_and_compute/crypto/context.h"
34 #include "private_join_and_compute/util/status.inc"
35 
36 namespace private_join_and_compute {
37 namespace {
38 
// Alphabet from which random identifiers are drawn: the ASCII digits followed
// by the lower-case and upper-case ASCII letters.
static const char kAlphaNumericCharacters[] =
    "1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM";
// Number of usable characters in the alphabet. Derived from the array (minus
// the trailing NUL) so that it cannot drift out of sync with the literal.
static const size_t kAlphaNumericSize = sizeof(kAlphaNumericCharacters) - 1;

// Creates a string of the specified length consisting of random letters and
// numbers.
//
// Each character is drawn uniformly from kAlphaNumericCharacters using a
// per-thread std::mt19937 that is seeded once from std::random_device. This
// replaces rand() % size, which repeats the same sequence on every run unless
// the caller seeds it, is slightly biased, and is not guaranteed thread-safe.
//
// Returns an empty string when `length` is 0.
std::string GetRandomAlphaNumericString(size_t length) {
  thread_local std::mt19937 generator{std::random_device{}()};
  std::uniform_int_distribution<size_t> pick_index(0, kAlphaNumericSize - 1);

  std::string output;
  output.reserve(length);
  for (size_t i = 0; i < length; i++) {
    output.push_back(kAlphaNumericCharacters[pick_index(generator)]);
  }
  return output;
}
54 
55 // Utility functions to convert a line to CSV format, and parse a CSV line into
56 // columns safely.
57 
// Copies at most `max_length` characters of `the_string` into a freshly
// allocated, NUL-terminated buffer of max_length + 1 bytes.
//
// Returns nullptr when `the_string` is nullptr. Otherwise the caller owns the
// returned buffer and must release it with delete[].
char* strndup_with_new(const char* the_string, size_t max_length) {
  if (!the_string) {
    return nullptr;
  }

  char* const copy = new char[max_length + 1];
  strncpy(copy, the_string, max_length);
  // strncpy does not terminate the destination when the source is at least
  // max_length characters long, so always terminate explicitly.
  copy[max_length] = '\0';
  return copy;
}
65 
// Splits `line` in place into columns separated by `delimiter`, appending a
// pointer to the start of each column to `cols`.
//
// The buffer behind `line` is modified: a NUL terminator is written at the end
// of every column and quoted values are unescaped in place. Every pointer
// pushed onto `cols` therefore points into that buffer and is only usable
// while the buffer is alive.
//
// Parsing rules:
//  - Whitespace before a value, and after an unquoted value, is dropped
//    unless the whitespace character is itself the delimiter.
//  - Only when `delimiter` is ',', a value starting with ["] is treated as
//    quoted: it runs up to the next lone ["], and [""] inside it stands for a
//    literal ["]. Anything between the closing quote and the next delimiter is
//    ignored.
//  - A delimiter as the very last character of the line produces one extra,
//    empty column.
//  - An empty line produces no columns.
void SplitCSVLineWithDelimiter(char* line, char delimiter,
                               std::vector<char*>* cols) {
  // std::isspace has undefined behavior for negative arguments, which plain
  // `char` produces for non-ASCII bytes (e.g. UTF-8) on platforms where char
  // is signed. Route every character through unsigned char first.
  const auto is_space = [](char c) {
    return std::isspace(static_cast<unsigned char>(c)) != 0;
  };

  char* end_of_line = line + strlen(line);
  char* end;
  char* start;

  for (; line < end_of_line; line++) {
    // Skip leading whitespace, unless said whitespace is the delimiter.
    while (is_space(*line) && *line != delimiter) ++line;

    if (*line == '"' && delimiter == ',') {  // Quoted value...
      start = ++line;
      end = start;
      // Compact the value in place: `end` trails `line` while escaped quotes
      // are collapsed.
      for (; *line; line++) {
        if (*line == '"') {
          line++;
          if (*line != '"')  // [""] is an escaped ["]
            break;           // but just ["] is end of value
        }
        *end++ = *line;
      }
      // All characters after the closing quote and before the comma
      // are ignored.
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
    } else {
      start = line;
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
      // Skip all trailing whitespace, unless said whitespace is the delimiter.
      for (end = line; end > start; --end) {
        if (!is_space(end[-1]) || end[-1] == delimiter) break;
      }
    }
    // Must be evaluated before the terminator is written below, because `end`
    // may alias `line` and overwrite the delimiter.
    const bool need_another_column =
        (*line == delimiter) && (line == end_of_line - 1);
    *end = '\0';
    cols->push_back(start);
    // If line was something like [paul,] (comma is the last character
    // and is not preceded by whitespace or quote) then we are about
    // to eliminate the last column (which is empty). This would be
    // incorrect.
    if (need_another_column) cols->push_back(end);

    assert(*line == '\0' || *line == delimiter);
  }
}
113 
SplitCSVLineWithDelimiterForStrings(const std::string & line,char delimiter,std::vector<std::string> * cols)114 void SplitCSVLineWithDelimiterForStrings(const std::string& line,
115                                          char delimiter,
116                                          std::vector<std::string>* cols) {
117   // Unfortunately, the interface requires char* instead of const char*
118   // which requires copying the string.
119   char* cline = strndup_with_new(line.c_str(), line.size());
120   std::vector<char*> v;
121   SplitCSVLineWithDelimiter(cline, delimiter, &v);
122   for (char* str : v) {
123     cols->push_back(str);
124   }
125   delete[] cline;
126 }
127 
128 // Escapes a string for CSV file writing. By default, this will surround each
129 // string with double quotes, and escape each occurrence of a double quote by
130 // replacing it with 2 double quotes.
EscapeForCsv(absl::string_view input)131 std::string EscapeForCsv(absl::string_view input) {
132   return absl::StrCat("\"", absl::StrReplaceAll(input, {{"\"", "\"\""}}), "\"");
133 }
134 
135 }  // namespace
136 
SplitCsvLine(const std::string & line)137 std::vector<std::string> SplitCsvLine(const std::string& line) {
138   std::vector<std::string> cols;
139   SplitCSVLineWithDelimiterForStrings(line, ',', &cols);
140   return cols;
141 }
142 
GenerateRandomDatabases(int64_t server_data_size,int64_t client_data_size,int64_t intersection_size,int64_t max_associated_value)143 auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size,
144                              int64_t intersection_size,
145                              int64_t max_associated_value)
146     -> StatusOr<std::tuple<
147         std::vector<std::string>,
148         std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>> {
149   // Check parameters
150   if (intersection_size < 0 || server_data_size < 0 || client_data_size < 0 ||
151       max_associated_value < 0) {
152     return InvalidArgumentError(
153         "GenerateRandomDatabases: Sizes cannot be negative.");
154   }
155   if (intersection_size > server_data_size ||
156       intersection_size > client_data_size) {
157     return InvalidArgumentError(
158         "GenerateRandomDatabases: intersection_size is larger than "
159         "client/server data size.");
160   }
161 
162   if (max_associated_value > 0 &&
163       intersection_size >
164           std::numeric_limits<int64_t>::max() / max_associated_value) {
165     return InvalidArgumentError(
166         "GenerateRandomDatabases: intersection_size * max_associated_value  is "
167         "larger than int64_t::max.");
168   }
169 
170   std::random_device rd;
171   std::mt19937 gen(rd());
172 
173   // Generate the random identifiers that are going to be in the intersection.
174   std::vector<std::string> common_identifiers;
175   common_identifiers.reserve(intersection_size);
176   for (int64_t i = 0; i < intersection_size; i++) {
177     common_identifiers.push_back(
178         GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
179   }
180 
181   // Generate remaining random identifiers for the server, and shuffle.
182   std::vector<std::string> server_identifiers = common_identifiers;
183   server_identifiers.reserve(server_data_size);
184   for (int64_t i = intersection_size; i < server_data_size; i++) {
185     server_identifiers.push_back(
186         GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
187   }
188   std::shuffle(server_identifiers.begin(), server_identifiers.end(), gen);
189 
190   // Generate remaining random identifiers for the client.
191   std::vector<std::string> client_identifiers = common_identifiers;
192   client_identifiers.reserve(client_data_size);
193   for (int64_t i = intersection_size; i < client_data_size; i++) {
194     client_identifiers.push_back(
195         GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
196   }
197   std::shuffle(client_identifiers.begin(), client_identifiers.end(), gen);
198 
199   absl::btree_set<std::string> server_identifiers_set(
200       server_identifiers.begin(), server_identifiers.end());
201 
202   // Generate associated values for the client, adding them to the intersection
203   // sum if the identifier is in common.
204   std::vector<int64_t> client_associated_values;
205   Context context;
206   BigNum associated_values_bound = context.CreateBigNum(max_associated_value);
207   client_associated_values.reserve(client_data_size);
208   int64_t intersection_sum = 0;
209   for (int64_t i = 0; i < client_data_size; i++) {
210     // Converting the associated value from BigNum to int64_t should never fail
211     // because associated_values_bound is less than int64_t::max.
212     int64_t associated_value =
213         context.GenerateRandLessThan(associated_values_bound)
214             .ToIntValue()
215             .value();
216     client_associated_values.push_back(associated_value);
217 
218     if (server_identifiers_set.count(client_identifiers[i]) > 0) {
219       intersection_sum += associated_value;
220     }
221   }
222 
223   // Return the output.
224   return std::make_tuple(std::move(server_identifiers),
225                          std::make_pair(std::move(client_identifiers),
226                                         std::move(client_associated_values)),
227                          intersection_sum);
228 }
229 
WriteServerDatasetToFile(const std::vector<std::string> & server_data,absl::string_view server_data_filename)230 Status WriteServerDatasetToFile(const std::vector<std::string>& server_data,
231                                 absl::string_view server_data_filename) {
232   // Open file.
233   std::ofstream server_data_file;
234   server_data_file.open(std::string(server_data_filename));
235   if (!server_data_file.is_open()) {
236     return InvalidArgumentError(absl::StrCat(
237         "WriteServerDatasetToFile: Couldn't open server data file: ",
238         server_data_filename));
239   }
240 
241   // Write each (escaped) line to file.
242   for (const auto& identifier : server_data) {
243     server_data_file << EscapeForCsv(identifier) << "\n";
244   }
245 
246   // Close file.
247   server_data_file.close();
248   if (server_data_file.fail()) {
249     return InternalError(
250         absl::StrCat("WriteServerDatasetToFile: Couldn't write to or close "
251                      "server data file: ",
252                      server_data_filename));
253   }
254 
255   return OkStatus();
256 }
257 
WriteClientDatasetToFile(const std::vector<std::string> & client_identifiers,const std::vector<int64_t> & client_associated_values,absl::string_view client_data_filename)258 Status WriteClientDatasetToFile(
259     const std::vector<std::string>& client_identifiers,
260     const std::vector<int64_t>& client_associated_values,
261     absl::string_view client_data_filename) {
262   if (client_associated_values.size() != client_identifiers.size()) {
263     return InvalidArgumentError(
264         "WriteClientDatasetToFile: there should be the same number of client "
265         "identifiers and associated values.");
266   }
267 
268   // Open file.
269   std::ofstream client_data_file;
270   client_data_file.open(std::string(client_data_filename));
271   if (!client_data_file.is_open()) {
272     return InvalidArgumentError(absl::StrCat(
273         "WriteClientDatasetToFile: Couldn't open client data file: ",
274         client_data_filename));
275   }
276 
277   // Write each (escaped) line to file.
278   for (size_t i = 0; i < client_identifiers.size(); i++) {
279     client_data_file << absl::StrCat(EscapeForCsv(client_identifiers[i]), ",",
280                                      client_associated_values[i])
281                      << "\n";
282   }
283 
284   // Close file.
285   client_data_file.close();
286   if (client_data_file.fail()) {
287     return InternalError(
288         absl::StrCat("WriteClientDatasetToFile: Couldn't write to or close "
289                      "client data file: ",
290                      client_data_filename));
291   }
292 
293   return OkStatus();
294 }
295 
ReadServerDatasetFromFile(absl::string_view server_data_filename)296 StatusOr<std::vector<std::string>> ReadServerDatasetFromFile(
297     absl::string_view server_data_filename) {
298   // Open file.
299   std::ifstream server_data_file;
300   server_data_file.open(std::string(server_data_filename));
301   if (!server_data_file.is_open()) {
302     return InvalidArgumentError(absl::StrCat(
303         "ReadServerDatasetFromFile: Couldn't open server data file: ",
304         server_data_filename));
305   }
306 
307   // Read each line from file (unescaping and splitting columns). Verify that
308   // each line contains a single column
309   std::vector<std::string> server_data;
310   std::string line;
311   int64_t line_number = 0;
312   while (std::getline(server_data_file, line)) {
313     std::vector<std::string> columns = SplitCsvLine(line);
314     if (columns.size() != 1) {
315       return InvalidArgumentError(absl::StrCat(
316           "ReadServerDatasetFromFile: Expected exactly 1 identifier per line, "
317           "but line ",
318           line_number, "has ", columns.size(),
319           " comma-separated items (file: ", server_data_filename, ")"));
320     }
321     server_data.push_back(columns[0]);
322     line_number++;
323   }
324 
325   // Close file.
326   server_data_file.close();
327   if (server_data_file.is_open()) {
328     return InternalError(absl::StrCat(
329         "ReadServerDatasetFromFile: Couldn't close server data file: ",
330         server_data_filename));
331   }
332 
333   return server_data;
334 }
335 
336 StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>>
ReadClientDatasetFromFile(absl::string_view client_data_filename,Context * context)337 ReadClientDatasetFromFile(absl::string_view client_data_filename,
338                           Context* context) {
339   // Open file.
340   std::ifstream client_data_file;
341   client_data_file.open(std::string(client_data_filename));
342   if (!client_data_file.is_open()) {
343     return InvalidArgumentError(absl::StrCat(
344         "ReadClientDatasetFromFile: Couldn't open client data file: ",
345         client_data_filename));
346   }
347 
348   // Read each line from file (unescaping and splitting columns). Verify that
349   // each line contains two columns, and parse the second column into an
350   // associated value.
351   std::vector<std::string> client_identifiers;
352   std::vector<BigNum> client_associated_values;
353   std::string line;
354   int64_t line_number = 0;
355   while (std::getline(client_data_file, line)) {
356     std::vector<std::string> columns = SplitCsvLine(line);
357     if (columns.size() != 2) {
358       return InvalidArgumentError(absl::StrCat(
359           "ReadClientDatasetFromFile: Expected exactly 2 items per line, "
360           "but line ",
361           line_number, "has ", columns.size(),
362           " comma-separated items (file: ", client_data_filename, ")"));
363     }
364     client_identifiers.push_back(columns[0]);
365     int64_t parsed_associated_value;
366     if (!absl::SimpleAtoi(columns[1], &parsed_associated_value) ||
367         parsed_associated_value < 0) {
368       return InvalidArgumentError(
369           absl::StrCat("ReadClientDatasetFromFile: could not parse a "
370                        "nonnegative associated value at line number",
371                        line_number));
372     }
373     client_associated_values.push_back(
374         context->CreateBigNum(parsed_associated_value));
375     line_number++;
376   }
377 
378   // Close file.
379   client_data_file.close();
380   if (client_data_file.is_open()) {
381     return InternalError(absl::StrCat(
382         "ReadClientDatasetFromFile: Couldn't close client data file: ",
383         client_data_filename));
384   }
385 
386   return std::make_pair(std::move(client_identifiers),
387                         std::move(client_associated_values));
388 }
389 
390 }  // namespace private_join_and_compute
391