1 /*
2 * Copyright 2019 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "private_join_and_compute/client_impl.h"
17
18 #include <algorithm>
19 #include <iostream>
20 #include <iterator>
21 #include <memory>
22 #include <ostream>
23 #include <string>
24 #include <tuple>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29
30 namespace private_join_and_compute {
31
32 PrivateIntersectionSumProtocolClientImpl::
PrivateIntersectionSumProtocolClientImpl(Context * ctx,const std::vector<std::string> & elements,const std::vector<BigNum> & values,int32_t modulus_size)33 PrivateIntersectionSumProtocolClientImpl(
34 Context* ctx, const std::vector<std::string>& elements,
35 const std::vector<BigNum>& values, int32_t modulus_size)
36 : ctx_(ctx),
37 elements_(elements),
38 values_(values),
39 p_(ctx_->GenerateSafePrime(modulus_size / 2)),
40 q_(ctx_->GenerateSafePrime(modulus_size / 2)),
41 intersection_sum_(ctx->Zero()),
42 ec_cipher_(std::move(
43 ECCommutativeCipher::CreateWithNewKey(
44 NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256)
45 .value())) {}
46
47 StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne>
ReEncryptSet(const PrivateIntersectionSumServerMessage::ServerRoundOne & message)48 PrivateIntersectionSumProtocolClientImpl::ReEncryptSet(
49 const PrivateIntersectionSumServerMessage::ServerRoundOne& message) {
50 private_paillier_ = std::make_unique<PrivatePaillier>(ctx_, p_, q_, 2);
51 BigNum pk = p_ * q_;
52 PrivateIntersectionSumClientMessage::ClientRoundOne result;
53 *result.mutable_public_key() = pk.ToBytes();
54 for (size_t i = 0; i < elements_.size(); i++) {
55 EncryptedElement* element = result.mutable_encrypted_set()->add_elements();
56 StatusOr<std::string> encrypted = ec_cipher_->Encrypt(elements_[i]);
57 if (!encrypted.ok()) {
58 return encrypted.status();
59 }
60 *element->mutable_element() = encrypted.value();
61 StatusOr<BigNum> value = private_paillier_->Encrypt(values_[i]);
62 if (!value.ok()) {
63 return value.status();
64 }
65 *element->mutable_associated_data() = value.value().ToBytes();
66 }
67
68 std::vector<EncryptedElement> reencrypted_set;
69 for (const EncryptedElement& element : message.encrypted_set().elements()) {
70 EncryptedElement reencrypted;
71 StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
72 if (!reenc.ok()) {
73 return reenc.status();
74 }
75 *reencrypted.mutable_element() = reenc.value();
76 reencrypted_set.push_back(reencrypted);
77 }
78 std::sort(reencrypted_set.begin(), reencrypted_set.end(),
79 [](const EncryptedElement& a, const EncryptedElement& b) {
80 return a.element() < b.element();
81 });
82 for (const EncryptedElement& element : reencrypted_set) {
83 *result.mutable_reencrypted_set()->add_elements() = element;
84 }
85
86 return result;
87 }
88
89 StatusOr<std::pair<int64_t, BigNum>>
DecryptSum(const PrivateIntersectionSumServerMessage::ServerRoundTwo & server_message)90 PrivateIntersectionSumProtocolClientImpl::DecryptSum(
91 const PrivateIntersectionSumServerMessage::ServerRoundTwo& server_message) {
92 if (private_paillier_ == nullptr) {
93 return InvalidArgumentError("Called DecryptSum before ReEncryptSet.");
94 }
95
96 StatusOr<BigNum> sum = private_paillier_->Decrypt(
97 ctx_->CreateBigNum(server_message.encrypted_sum()));
98 if (!sum.ok()) {
99 return sum.status();
100 }
101 return std::make_pair(server_message.intersection_size(), sum.value());
102 }
103
StartProtocol(MessageSink<ClientMessage> * client_message_sink)104 Status PrivateIntersectionSumProtocolClientImpl::StartProtocol(
105 MessageSink<ClientMessage>* client_message_sink) {
106 ClientMessage client_message;
107 *(client_message.mutable_private_intersection_sum_client_message()
108 ->mutable_start_protocol_request()) =
109 PrivateIntersectionSumClientMessage::StartProtocolRequest();
110 return client_message_sink->Send(client_message);
111 }
112
Handle(const ServerMessage & server_message,MessageSink<ClientMessage> * client_message_sink)113 Status PrivateIntersectionSumProtocolClientImpl::Handle(
114 const ServerMessage& server_message,
115 MessageSink<ClientMessage>* client_message_sink) {
116 if (protocol_finished()) {
117 return InvalidArgumentError(
118 "PrivateIntersectionSumProtocolClientImpl: Protocol is already "
119 "complete.");
120 }
121
122 // Check that the message is a PrivateIntersectionSum protocol message.
123 if (!server_message.has_private_intersection_sum_server_message()) {
124 return InvalidArgumentError(
125 "PrivateIntersectionSumProtocolClientImpl: Received a message for the "
126 "wrong protocol type");
127 }
128
129 if (server_message.private_intersection_sum_server_message()
130 .has_server_round_one()) {
131 // Handle the server round one message.
132 ClientMessage client_message;
133
134 auto maybe_client_round_one =
135 ReEncryptSet(server_message.private_intersection_sum_server_message()
136 .server_round_one());
137 if (!maybe_client_round_one.ok()) {
138 return maybe_client_round_one.status();
139 }
140 *(client_message.mutable_private_intersection_sum_client_message()
141 ->mutable_client_round_one()) =
142 std::move(maybe_client_round_one.value());
143 return client_message_sink->Send(client_message);
144 } else if (server_message.private_intersection_sum_server_message()
145 .has_server_round_two()) {
146 // Handle the server round two message.
147 auto maybe_result =
148 DecryptSum(server_message.private_intersection_sum_server_message()
149 .server_round_two());
150 if (!maybe_result.ok()) {
151 return maybe_result.status();
152 }
153 std::tie(intersection_size_, intersection_sum_) =
154 std::move(maybe_result.value());
155 // Mark the protocol as finished here.
156 protocol_finished_ = true;
157 return OkStatus();
158 }
159 // If none of the previous cases matched, we received the wrong kind of
160 // message.
161 return InvalidArgumentError(
162 "PrivateIntersectionSumProtocolClientImpl: Received a server message "
163 "of an unknown type.");
164 }
165
PrintOutput()166 Status PrivateIntersectionSumProtocolClientImpl::PrintOutput() {
167 if (!protocol_finished()) {
168 return InvalidArgumentError(
169 "PrivateIntersectionSumProtocolClientImpl: Not ready to print the "
170 "output yet.");
171 }
172 auto maybe_converted_intersection_sum = intersection_sum_.ToIntValue();
173 if (!maybe_converted_intersection_sum.ok()) {
174 return maybe_converted_intersection_sum.status();
175 }
176 std::cout << "Client: The intersection size is " << intersection_size_
177 << " and the intersection-sum is "
178 << maybe_converted_intersection_sum.value() << std::endl;
179 return OkStatus();
180 }
181
182 } // namespace private_join_and_compute
183