xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/client.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <iostream>
17 #include <memory>
18 #include <ostream>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/flags/flag.h"
23 #include "absl/flags/parse.h"
24 #include "absl/strings/str_cat.h"
25 #include "include/grpc/grpc_security_constants.h"
26 #include "include/grpcpp/channel.h"
27 #include "include/grpcpp/client_context.h"
28 #include "include/grpcpp/create_channel.h"
29 #include "include/grpcpp/grpcpp.h"
30 #include "include/grpcpp/security/credentials.h"
31 #include "include/grpcpp/support/status.h"
32 #include "private_join_and_compute/client_impl.h"
33 #include "private_join_and_compute/data_util.h"
34 #include "private_join_and_compute/private_join_and_compute.grpc.pb.h"
35 #include "private_join_and_compute/private_join_and_compute.pb.h"
36 #include "private_join_and_compute/protocol_client.h"
37 #include "private_join_and_compute/util/status.inc"
38 
39 ABSL_FLAG(std::string, port, "0.0.0.0:10501",
40           "Port on which to contact server");
41 ABSL_FLAG(std::string, client_data_file, "",
42           "The file from which to read the client database.");
43 ABSL_FLAG(
44     int32_t, paillier_modulus_size, 1536,
45     "The bit-length of the modulus to use for Paillier encryption. The modulus "
46     "will be the product of two safe primes, each of size "
47     "paillier_modulus_size/2.");
48 
49 namespace private_join_and_compute {
50 namespace {
51 
52 class InvokeServerHandleClientMessageSink : public MessageSink<ClientMessage> {
53  public:
InvokeServerHandleClientMessageSink(std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)54   explicit InvokeServerHandleClientMessageSink(
55       std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)
56       : stub_(std::move(stub)) {}
57 
58   ~InvokeServerHandleClientMessageSink() override = default;
59 
Send(const ClientMessage & message)60   Status Send(const ClientMessage& message) override {
61     ::grpc::ClientContext client_context;
62     ::grpc::Status grpc_status =
63         stub_->Handle(&client_context, message, &last_server_response_);
64     if (grpc_status.ok()) {
65       return OkStatus();
66     } else {
67       return InternalError(absl::StrCat(
68           "GrpcClientMessageSink: Failed to send message, error code: ",
69           grpc_status.error_code(),
70           ", error_message: ", grpc_status.error_message()));
71     }
72   }
73 
last_server_response()74   const ServerMessage& last_server_response() { return last_server_response_; }
75 
76  private:
77   std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub_;
78   ServerMessage last_server_response_;
79 };
80 
ExecuteProtocol()81 int ExecuteProtocol() {
82   ::private_join_and_compute::Context context;
83 
84   std::cout << "Client: Loading data..." << std::endl;
85   auto maybe_client_identifiers_and_associated_values =
86       ::private_join_and_compute::ReadClientDatasetFromFile(
87           absl::GetFlag(FLAGS_client_data_file), &context);
88   if (!maybe_client_identifiers_and_associated_values.ok()) {
89     std::cerr << "Client::ExecuteProtocol: failed "
90               << maybe_client_identifiers_and_associated_values.status()
91               << std::endl;
92     return 1;
93   }
94   auto client_identifiers_and_associated_values =
95       std::move(maybe_client_identifiers_and_associated_values.value());
96 
97   std::cout << "Client: Generating keys..." << std::endl;
98   std::unique_ptr<::private_join_and_compute::ProtocolClient> client =
99       std::make_unique<
100           ::private_join_and_compute::PrivateIntersectionSumProtocolClientImpl>(
101           &context, std::move(client_identifiers_and_associated_values.first),
102           std::move(client_identifiers_and_associated_values.second),
103           absl::GetFlag(FLAGS_paillier_modulus_size));
104 
105   // Consider grpc::SslServerCredentials if not running locally.
106   std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub =
107       PrivateJoinAndComputeRpc::NewStub(::grpc::CreateChannel(
108           absl::GetFlag(FLAGS_port), ::grpc::experimental::LocalCredentials(
109                                          grpc_local_connect_type::LOCAL_TCP)));
110   InvokeServerHandleClientMessageSink invoke_server_handle_message_sink(
111       std::move(stub));
112 
113   // Execute StartProtocol and wait for response from ServerRoundOne.
114   std::cout
115       << "Client: Starting the protocol." << std::endl
116       << "Client: Waiting for response and encrypted set from the server..."
117       << std::endl;
118   auto start_protocol_status =
119       client->StartProtocol(&invoke_server_handle_message_sink);
120   if (!start_protocol_status.ok()) {
121     std::cerr << "Client::ExecuteProtocol: failed to StartProtocol: "
122               << start_protocol_status << std::endl;
123     return 1;
124   }
125   ServerMessage server_round_one =
126       invoke_server_handle_message_sink.last_server_response();
127 
128   // Execute ClientRoundOne, and wait for response from ServerRoundTwo.
129   std::cout
130       << "Client: Received encrypted set from the server, double encrypting..."
131       << std::endl;
132   std::cout << "Client: Sending double encrypted server data and "
133                "single-encrypted client data to the server."
134             << std::endl
135             << "Client: Waiting for encrypted intersection sum..." << std::endl;
136   auto client_round_one_status =
137       client->Handle(server_round_one, &invoke_server_handle_message_sink);
138   if (!client_round_one_status.ok()) {
139     std::cerr << "Client::ExecuteProtocol: failed to ReEncryptSet: "
140               << client_round_one_status << std::endl;
141     return 1;
142   }
143 
144   // Execute ServerRoundTwo.
145   std::cout << "Client: Sending double encrypted server data and "
146                "single-encrypted client data to the server."
147             << std::endl
148             << "Client: Waiting for encrypted intersection sum..." << std::endl;
149   ServerMessage server_round_two =
150       invoke_server_handle_message_sink.last_server_response();
151 
152   // Compute the intersection size and sum.
153   std::cout << "Client: Received response from the server. Decrypting the "
154                "intersection-sum."
155             << std::endl;
156   auto intersection_size_and_sum_status =
157       client->Handle(server_round_two, &invoke_server_handle_message_sink);
158   if (!intersection_size_and_sum_status.ok()) {
159     std::cerr << "Client::ExecuteProtocol: failed to DecryptSum: "
160               << intersection_size_and_sum_status << std::endl;
161     return 1;
162   }
163 
164   // Output the result.
165   auto client_print_output_status = client->PrintOutput();
166   if (!client_print_output_status.ok()) {
167     std::cerr << "Client::ExecuteProtocol: failed to PrintOutput: "
168               << client_print_output_status << std::endl;
169     return 1;
170   }
171 
172   return 0;
173 }
174 
175 }  // namespace
176 }  // namespace private_join_and_compute
177 
main(int argc,char ** argv)178 int main(int argc, char** argv) {
179   absl::ParseCommandLine(argc, argv);
180 
181   return private_join_and_compute::ExecuteProtocol();
182 }
183