xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/client_impl.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/client_impl.h"
17 
18 #include <algorithm>
19 #include <iostream>
20 #include <iterator>
21 #include <memory>
22 #include <ostream>
23 #include <string>
24 #include <tuple>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 
30 namespace private_join_and_compute {
31 
32 PrivateIntersectionSumProtocolClientImpl::
PrivateIntersectionSumProtocolClientImpl(Context * ctx,const std::vector<std::string> & elements,const std::vector<BigNum> & values,int32_t modulus_size)33     PrivateIntersectionSumProtocolClientImpl(
34         Context* ctx, const std::vector<std::string>& elements,
35         const std::vector<BigNum>& values, int32_t modulus_size)
36     : ctx_(ctx),
37       elements_(elements),
38       values_(values),
39       p_(ctx_->GenerateSafePrime(modulus_size / 2)),
40       q_(ctx_->GenerateSafePrime(modulus_size / 2)),
41       intersection_sum_(ctx->Zero()),
42       ec_cipher_(std::move(
43           ECCommutativeCipher::CreateWithNewKey(
44               NID_X9_62_prime256v1, ECCommutativeCipher::HashType::SHA256)
45               .value())) {}
46 
47 StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne>
ReEncryptSet(const PrivateIntersectionSumServerMessage::ServerRoundOne & message)48 PrivateIntersectionSumProtocolClientImpl::ReEncryptSet(
49     const PrivateIntersectionSumServerMessage::ServerRoundOne& message) {
50   private_paillier_ = std::make_unique<PrivatePaillier>(ctx_, p_, q_, 2);
51   BigNum pk = p_ * q_;
52   PrivateIntersectionSumClientMessage::ClientRoundOne result;
53   *result.mutable_public_key() = pk.ToBytes();
54   for (size_t i = 0; i < elements_.size(); i++) {
55     EncryptedElement* element = result.mutable_encrypted_set()->add_elements();
56     StatusOr<std::string> encrypted = ec_cipher_->Encrypt(elements_[i]);
57     if (!encrypted.ok()) {
58       return encrypted.status();
59     }
60     *element->mutable_element() = encrypted.value();
61     StatusOr<BigNum> value = private_paillier_->Encrypt(values_[i]);
62     if (!value.ok()) {
63       return value.status();
64     }
65     *element->mutable_associated_data() = value.value().ToBytes();
66   }
67 
68   std::vector<EncryptedElement> reencrypted_set;
69   for (const EncryptedElement& element : message.encrypted_set().elements()) {
70     EncryptedElement reencrypted;
71     StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
72     if (!reenc.ok()) {
73       return reenc.status();
74     }
75     *reencrypted.mutable_element() = reenc.value();
76     reencrypted_set.push_back(reencrypted);
77   }
78   std::sort(reencrypted_set.begin(), reencrypted_set.end(),
79             [](const EncryptedElement& a, const EncryptedElement& b) {
80               return a.element() < b.element();
81             });
82   for (const EncryptedElement& element : reencrypted_set) {
83     *result.mutable_reencrypted_set()->add_elements() = element;
84   }
85 
86   return result;
87 }
88 
89 StatusOr<std::pair<int64_t, BigNum>>
DecryptSum(const PrivateIntersectionSumServerMessage::ServerRoundTwo & server_message)90 PrivateIntersectionSumProtocolClientImpl::DecryptSum(
91     const PrivateIntersectionSumServerMessage::ServerRoundTwo& server_message) {
92   if (private_paillier_ == nullptr) {
93     return InvalidArgumentError("Called DecryptSum before ReEncryptSet.");
94   }
95 
96   StatusOr<BigNum> sum = private_paillier_->Decrypt(
97       ctx_->CreateBigNum(server_message.encrypted_sum()));
98   if (!sum.ok()) {
99     return sum.status();
100   }
101   return std::make_pair(server_message.intersection_size(), sum.value());
102 }
103 
StartProtocol(MessageSink<ClientMessage> * client_message_sink)104 Status PrivateIntersectionSumProtocolClientImpl::StartProtocol(
105     MessageSink<ClientMessage>* client_message_sink) {
106   ClientMessage client_message;
107   *(client_message.mutable_private_intersection_sum_client_message()
108         ->mutable_start_protocol_request()) =
109       PrivateIntersectionSumClientMessage::StartProtocolRequest();
110   return client_message_sink->Send(client_message);
111 }
112 
Handle(const ServerMessage & server_message,MessageSink<ClientMessage> * client_message_sink)113 Status PrivateIntersectionSumProtocolClientImpl::Handle(
114     const ServerMessage& server_message,
115     MessageSink<ClientMessage>* client_message_sink) {
116   if (protocol_finished()) {
117     return InvalidArgumentError(
118         "PrivateIntersectionSumProtocolClientImpl: Protocol is already "
119         "complete.");
120   }
121 
122   // Check that the message is a PrivateIntersectionSum protocol message.
123   if (!server_message.has_private_intersection_sum_server_message()) {
124     return InvalidArgumentError(
125         "PrivateIntersectionSumProtocolClientImpl: Received a message for the "
126         "wrong protocol type");
127   }
128 
129   if (server_message.private_intersection_sum_server_message()
130           .has_server_round_one()) {
131     // Handle the server round one message.
132     ClientMessage client_message;
133 
134     auto maybe_client_round_one =
135         ReEncryptSet(server_message.private_intersection_sum_server_message()
136                          .server_round_one());
137     if (!maybe_client_round_one.ok()) {
138       return maybe_client_round_one.status();
139     }
140     *(client_message.mutable_private_intersection_sum_client_message()
141           ->mutable_client_round_one()) =
142         std::move(maybe_client_round_one.value());
143     return client_message_sink->Send(client_message);
144   } else if (server_message.private_intersection_sum_server_message()
145                  .has_server_round_two()) {
146     // Handle the server round two message.
147     auto maybe_result =
148         DecryptSum(server_message.private_intersection_sum_server_message()
149                        .server_round_two());
150     if (!maybe_result.ok()) {
151       return maybe_result.status();
152     }
153     std::tie(intersection_size_, intersection_sum_) =
154         std::move(maybe_result.value());
155     // Mark the protocol as finished here.
156     protocol_finished_ = true;
157     return OkStatus();
158   }
159   // If none of the previous cases matched, we received the wrong kind of
160   // message.
161   return InvalidArgumentError(
162       "PrivateIntersectionSumProtocolClientImpl: Received a server message "
163       "of an unknown type.");
164 }
165 
PrintOutput()166 Status PrivateIntersectionSumProtocolClientImpl::PrintOutput() {
167   if (!protocol_finished()) {
168     return InvalidArgumentError(
169         "PrivateIntersectionSumProtocolClientImpl: Not ready to print the "
170         "output yet.");
171   }
172   auto maybe_converted_intersection_sum = intersection_sum_.ToIntValue();
173   if (!maybe_converted_intersection_sum.ok()) {
174     return maybe_converted_intersection_sum.status();
175   }
176   std::cout << "Client: The intersection size is " << intersection_size_
177             << " and the intersection-sum is "
178             << maybe_converted_intersection_sum.value() << std::endl;
179   return OkStatus();
180 }
181 
182 }  // namespace private_join_and_compute
183