1 /*
2 * Copyright 2019 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "private_join_and_compute/data_util.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/btree_set.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "private_join_and_compute/crypto/context.h"
#include "private_join_and_compute/util/status.inc"
35
36 namespace private_join_and_compute {
37 namespace {
38
constexpr char kAlphaNumericCharacters[] =
    "1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM";
// Number of selectable characters. Derived from the table itself (minus the
// trailing '\0') so the two can never drift apart.
constexpr size_t kAlphaNumericSize = sizeof(kAlphaNumericCharacters) - 1;
static_assert(kAlphaNumericSize == 62,
              "kAlphaNumericCharacters should hold 10 digits and 52 letters");

// Creates a string of the specified length consisting of random letters and
// numbers.
//
// Characters are drawn with rand(), so the output sequence is determined by the
// current srand() seed. rand() is not a cryptographic source, and
// `rand() % kAlphaNumericSize` is slightly biased. This is only meant for
// generating test data.
std::string GetRandomAlphaNumericString(size_t length) {
  std::string output;
  output.reserve(length);
  for (size_t i = 0; i < length; ++i) {
    output.push_back(kAlphaNumericCharacters[rand() % kAlphaNumericSize]);
  }
  return output;
}
54
55 // Utility functions to convert a line to CSV format, and parse a CSV line into
56 // columns safely.
57
// Copies at most `max_length` characters of `the_string` into a freshly
// allocated, always NUL-terminated buffer of `max_length + 1` chars.
//
// The caller owns the returned buffer and must release it with `delete[]`.
// Returns nullptr (and allocates nothing) when `the_string` is nullptr.
char* strndup_with_new(const char* the_string, size_t max_length) {
  if (!the_string) {
    return nullptr;
  }
  char* const copy = new char[max_length + 1];
  strncpy(copy, the_string, max_length);
  // strncpy leaves the buffer unterminated when the source has at least
  // max_length characters, so terminate the final slot explicitly.
  copy[max_length] = '\0';
  return copy;
}
65
// Splits the NUL-terminated `line` into columns separated by `delimiter`,
// tokenizing in place.
//
// Column terminators (and trimmed trailing whitespace) are overwritten with
// '\0'. `cols` receives pointers into `line`, so they stay valid only while
// `line` is alive and unmodified.
//
// Rules, as implemented below:
//  * Leading whitespace before every column is skipped, and trailing
//    whitespace after an unquoted column is dropped, unless that whitespace
//    is the delimiter itself.
//  * Only when the delimiter is ',', a column starting with '"' is a quoted
//    value. Inside it, "" stands for a literal ", and the value ends at the
//    next lone ". Anything between the closing quote and the next delimiter
//    is ignored.
//  * A delimiter as the very last character yields a final empty column. An
//    empty line yields no columns.
void SplitCSVLineWithDelimiter(char* line, char delimiter,
                               std::vector<char*>* cols) {
  // std::isspace has undefined behavior for negative values other than EOF,
  // and plain char is signed on many platforms. Bytes >= 0x80 (for example
  // UTF-8) must therefore be converted to unsigned char first.
  const auto is_space = [](char c) {
    return std::isspace(static_cast<unsigned char>(c)) != 0;
  };

  char* end_of_line = line + strlen(line);
  char* end;
  char* start;

  for (; line < end_of_line; line++) {
    // Skip leading whitespace, unless said whitespace is the delimiter.
    while (is_space(*line) && *line != delimiter) ++line;

    if (*line == '"' && delimiter == ',') {  // Quoted value...
      start = ++line;
      end = start;
      // Compact the value in place: `end` trails `line`, dropping the first
      // quote of every "" pair.
      for (; *line; line++) {
        if (*line == '"') {
          line++;
          if (*line != '"')  // [""] is an escaped ["]
            break;           // but just ["] is end of value
        }
        *end++ = *line;
      }
      // All characters after the closing quote and before the comma
      // are ignored.
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
    } else {
      start = line;
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
      // Skip all trailing whitespace, unless said whitespace is the delimiter.
      for (end = line; end > start; --end) {
        if (!is_space(end[-1]) || end[-1] == delimiter) break;
      }
    }
    // Compute this before `*end = '\0'` below, which overwrites the delimiter
    // whenever `end == line`.
    const bool need_another_column =
        (*line == delimiter) && (line == end_of_line - 1);
    *end = '\0';
    cols->push_back(start);
    // When the delimiter is the very last character (e.g. [paul,]), the loop
    // exits without ever visiting the empty column that follows it, so emit
    // that empty column here. `end` now points at a '\0', so it is a valid
    // empty string.
    if (need_another_column) cols->push_back(end);

    assert(*line == '\0' || *line == delimiter);
  }
}
113
SplitCSVLineWithDelimiterForStrings(const std::string & line,char delimiter,std::vector<std::string> * cols)114 void SplitCSVLineWithDelimiterForStrings(const std::string& line,
115 char delimiter,
116 std::vector<std::string>* cols) {
117 // Unfortunately, the interface requires char* instead of const char*
118 // which requires copying the string.
119 char* cline = strndup_with_new(line.c_str(), line.size());
120 std::vector<char*> v;
121 SplitCSVLineWithDelimiter(cline, delimiter, &v);
122 for (char* str : v) {
123 cols->push_back(str);
124 }
125 delete[] cline;
126 }
127
128 // Escapes a string for CSV file writing. By default, this will surround each
129 // string with double quotes, and escape each occurrence of a double quote by
130 // replacing it with 2 double quotes.
EscapeForCsv(absl::string_view input)131 std::string EscapeForCsv(absl::string_view input) {
132 return absl::StrCat("\"", absl::StrReplaceAll(input, {{"\"", "\"\""}}), "\"");
133 }
134
135 } // namespace
136
SplitCsvLine(const std::string & line)137 std::vector<std::string> SplitCsvLine(const std::string& line) {
138 std::vector<std::string> cols;
139 SplitCSVLineWithDelimiterForStrings(line, ',', &cols);
140 return cols;
141 }
142
GenerateRandomDatabases(int64_t server_data_size,int64_t client_data_size,int64_t intersection_size,int64_t max_associated_value)143 auto GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size,
144 int64_t intersection_size,
145 int64_t max_associated_value)
146 -> StatusOr<std::tuple<
147 std::vector<std::string>,
148 std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>> {
149 // Check parameters
150 if (intersection_size < 0 || server_data_size < 0 || client_data_size < 0 ||
151 max_associated_value < 0) {
152 return InvalidArgumentError(
153 "GenerateRandomDatabases: Sizes cannot be negative.");
154 }
155 if (intersection_size > server_data_size ||
156 intersection_size > client_data_size) {
157 return InvalidArgumentError(
158 "GenerateRandomDatabases: intersection_size is larger than "
159 "client/server data size.");
160 }
161
162 if (max_associated_value > 0 &&
163 intersection_size >
164 std::numeric_limits<int64_t>::max() / max_associated_value) {
165 return InvalidArgumentError(
166 "GenerateRandomDatabases: intersection_size * max_associated_value is "
167 "larger than int64_t::max.");
168 }
169
170 std::random_device rd;
171 std::mt19937 gen(rd());
172
173 // Generate the random identifiers that are going to be in the intersection.
174 std::vector<std::string> common_identifiers;
175 common_identifiers.reserve(intersection_size);
176 for (int64_t i = 0; i < intersection_size; i++) {
177 common_identifiers.push_back(
178 GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
179 }
180
181 // Generate remaining random identifiers for the server, and shuffle.
182 std::vector<std::string> server_identifiers = common_identifiers;
183 server_identifiers.reserve(server_data_size);
184 for (int64_t i = intersection_size; i < server_data_size; i++) {
185 server_identifiers.push_back(
186 GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
187 }
188 std::shuffle(server_identifiers.begin(), server_identifiers.end(), gen);
189
190 // Generate remaining random identifiers for the client.
191 std::vector<std::string> client_identifiers = common_identifiers;
192 client_identifiers.reserve(client_data_size);
193 for (int64_t i = intersection_size; i < client_data_size; i++) {
194 client_identifiers.push_back(
195 GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
196 }
197 std::shuffle(client_identifiers.begin(), client_identifiers.end(), gen);
198
199 absl::btree_set<std::string> server_identifiers_set(
200 server_identifiers.begin(), server_identifiers.end());
201
202 // Generate associated values for the client, adding them to the intersection
203 // sum if the identifier is in common.
204 std::vector<int64_t> client_associated_values;
205 Context context;
206 BigNum associated_values_bound = context.CreateBigNum(max_associated_value);
207 client_associated_values.reserve(client_data_size);
208 int64_t intersection_sum = 0;
209 for (int64_t i = 0; i < client_data_size; i++) {
210 // Converting the associated value from BigNum to int64_t should never fail
211 // because associated_values_bound is less than int64_t::max.
212 int64_t associated_value =
213 context.GenerateRandLessThan(associated_values_bound)
214 .ToIntValue()
215 .value();
216 client_associated_values.push_back(associated_value);
217
218 if (server_identifiers_set.count(client_identifiers[i]) > 0) {
219 intersection_sum += associated_value;
220 }
221 }
222
223 // Return the output.
224 return std::make_tuple(std::move(server_identifiers),
225 std::make_pair(std::move(client_identifiers),
226 std::move(client_associated_values)),
227 intersection_sum);
228 }
229
WriteServerDatasetToFile(const std::vector<std::string> & server_data,absl::string_view server_data_filename)230 Status WriteServerDatasetToFile(const std::vector<std::string>& server_data,
231 absl::string_view server_data_filename) {
232 // Open file.
233 std::ofstream server_data_file;
234 server_data_file.open(std::string(server_data_filename));
235 if (!server_data_file.is_open()) {
236 return InvalidArgumentError(absl::StrCat(
237 "WriteServerDatasetToFile: Couldn't open server data file: ",
238 server_data_filename));
239 }
240
241 // Write each (escaped) line to file.
242 for (const auto& identifier : server_data) {
243 server_data_file << EscapeForCsv(identifier) << "\n";
244 }
245
246 // Close file.
247 server_data_file.close();
248 if (server_data_file.fail()) {
249 return InternalError(
250 absl::StrCat("WriteServerDatasetToFile: Couldn't write to or close "
251 "server data file: ",
252 server_data_filename));
253 }
254
255 return OkStatus();
256 }
257
WriteClientDatasetToFile(const std::vector<std::string> & client_identifiers,const std::vector<int64_t> & client_associated_values,absl::string_view client_data_filename)258 Status WriteClientDatasetToFile(
259 const std::vector<std::string>& client_identifiers,
260 const std::vector<int64_t>& client_associated_values,
261 absl::string_view client_data_filename) {
262 if (client_associated_values.size() != client_identifiers.size()) {
263 return InvalidArgumentError(
264 "WriteClientDatasetToFile: there should be the same number of client "
265 "identifiers and associated values.");
266 }
267
268 // Open file.
269 std::ofstream client_data_file;
270 client_data_file.open(std::string(client_data_filename));
271 if (!client_data_file.is_open()) {
272 return InvalidArgumentError(absl::StrCat(
273 "WriteClientDatasetToFile: Couldn't open client data file: ",
274 client_data_filename));
275 }
276
277 // Write each (escaped) line to file.
278 for (size_t i = 0; i < client_identifiers.size(); i++) {
279 client_data_file << absl::StrCat(EscapeForCsv(client_identifiers[i]), ",",
280 client_associated_values[i])
281 << "\n";
282 }
283
284 // Close file.
285 client_data_file.close();
286 if (client_data_file.fail()) {
287 return InternalError(
288 absl::StrCat("WriteClientDatasetToFile: Couldn't write to or close "
289 "client data file: ",
290 client_data_filename));
291 }
292
293 return OkStatus();
294 }
295
ReadServerDatasetFromFile(absl::string_view server_data_filename)296 StatusOr<std::vector<std::string>> ReadServerDatasetFromFile(
297 absl::string_view server_data_filename) {
298 // Open file.
299 std::ifstream server_data_file;
300 server_data_file.open(std::string(server_data_filename));
301 if (!server_data_file.is_open()) {
302 return InvalidArgumentError(absl::StrCat(
303 "ReadServerDatasetFromFile: Couldn't open server data file: ",
304 server_data_filename));
305 }
306
307 // Read each line from file (unescaping and splitting columns). Verify that
308 // each line contains a single column
309 std::vector<std::string> server_data;
310 std::string line;
311 int64_t line_number = 0;
312 while (std::getline(server_data_file, line)) {
313 std::vector<std::string> columns = SplitCsvLine(line);
314 if (columns.size() != 1) {
315 return InvalidArgumentError(absl::StrCat(
316 "ReadServerDatasetFromFile: Expected exactly 1 identifier per line, "
317 "but line ",
318 line_number, "has ", columns.size(),
319 " comma-separated items (file: ", server_data_filename, ")"));
320 }
321 server_data.push_back(columns[0]);
322 line_number++;
323 }
324
325 // Close file.
326 server_data_file.close();
327 if (server_data_file.is_open()) {
328 return InternalError(absl::StrCat(
329 "ReadServerDatasetFromFile: Couldn't close server data file: ",
330 server_data_filename));
331 }
332
333 return server_data;
334 }
335
336 StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>>
ReadClientDatasetFromFile(absl::string_view client_data_filename,Context * context)337 ReadClientDatasetFromFile(absl::string_view client_data_filename,
338 Context* context) {
339 // Open file.
340 std::ifstream client_data_file;
341 client_data_file.open(std::string(client_data_filename));
342 if (!client_data_file.is_open()) {
343 return InvalidArgumentError(absl::StrCat(
344 "ReadClientDatasetFromFile: Couldn't open client data file: ",
345 client_data_filename));
346 }
347
348 // Read each line from file (unescaping and splitting columns). Verify that
349 // each line contains two columns, and parse the second column into an
350 // associated value.
351 std::vector<std::string> client_identifiers;
352 std::vector<BigNum> client_associated_values;
353 std::string line;
354 int64_t line_number = 0;
355 while (std::getline(client_data_file, line)) {
356 std::vector<std::string> columns = SplitCsvLine(line);
357 if (columns.size() != 2) {
358 return InvalidArgumentError(absl::StrCat(
359 "ReadClientDatasetFromFile: Expected exactly 2 items per line, "
360 "but line ",
361 line_number, "has ", columns.size(),
362 " comma-separated items (file: ", client_data_filename, ")"));
363 }
364 client_identifiers.push_back(columns[0]);
365 int64_t parsed_associated_value;
366 if (!absl::SimpleAtoi(columns[1], &parsed_associated_value) ||
367 parsed_associated_value < 0) {
368 return InvalidArgumentError(
369 absl::StrCat("ReadClientDatasetFromFile: could not parse a "
370 "nonnegative associated value at line number",
371 line_number));
372 }
373 client_associated_values.push_back(
374 context->CreateBigNum(parsed_associated_value));
375 line_number++;
376 }
377
378 // Close file.
379 client_data_file.close();
380 if (client_data_file.is_open()) {
381 return InternalError(absl::StrCat(
382 "ReadClientDatasetFromFile: Couldn't close client data file: ",
383 client_data_filename));
384 }
385
386 return std::make_pair(std::move(client_identifiers),
387 std::move(client_associated_values));
388 }
389
390 } // namespace private_join_and_compute
391