xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/client_impl.h (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
17 #define PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
18 
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "private_join_and_compute/crypto/context.h"
#include "private_join_and_compute/crypto/ec_commutative_cipher.h"
#include "private_join_and_compute/crypto/paillier.h"
#include "private_join_and_compute/match.pb.h"
#include "private_join_and_compute/message_sink.h"
#include "private_join_and_compute/private_intersection_sum.pb.h"
#include "private_join_and_compute/private_join_and_compute.pb.h"
#include "private_join_and_compute/protocol_client.h"
#include "private_join_and_compute/util/status.inc"
33 
34 namespace private_join_and_compute {
35 
36 // This class represents the "client" part of the intersection-sum protocol,
37 // which supplies the associated values that will be used to compute the sum.
38 // This is the party that will receive the sum as output.
39 class PrivateIntersectionSumProtocolClientImpl : public ProtocolClient {
40  public:
41   PrivateIntersectionSumProtocolClientImpl(
42       Context* ctx, const std::vector<std::string>& elements,
43       const std::vector<BigNum>& values, int32_t modulus_size);
44 
45   // Generates the StartProtocol message and sends it on the message sink.
46   Status StartProtocol(
47       MessageSink<ClientMessage>* client_message_sink) override;
48 
49   // Executes the next Client round and creates a new server request, which must
50   // be sent to the server unless the protocol is finished.
51   //
52   // If the ServerMessage is ServerRoundOne, a ClientRoundOne will be sent on
53   // the message sink, containing the encrypted client identifiers and
54   // associated values, and the re-encrypted and shuffled server identifiers.
55   //
56   // If the ServerMessage is ServerRoundTwo, nothing will be sent on
57   // the message sink, and the client will internally store the intersection sum
58   // and size. The intersection sum and size can be retrieved either through
59   // accessors, or by calling PrintOutput.
60   //
61   // Fails with InvalidArgument if the message is not a
62   // PrivateIntersectionSumServerMessage of the expected round, or if the
63   // message is otherwise not as expected. Forwards all other failures
64   // encountered.
65   Status Handle(const ServerMessage& server_message,
66                 MessageSink<ClientMessage>* client_message_sink) override;
67 
68   // Prints the result, namely the intersection size and the intersection sum.
69   Status PrintOutput() override;
70 
protocol_finished()71   bool protocol_finished() override { return protocol_finished_; }
72 
73   // Utility functions for testing.
intersection_size()74   int64_t intersection_size() const { return intersection_size_; }
intersection_sum()75   const BigNum& intersection_sum() const { return intersection_sum_; }
76 
77  private:
78   // The server sends the first message of the protocol, which contains its
79   // encrypted set.  This party then re-encrypts that set and replies with the
80   // reencrypted values and its own encrypted set.
81   StatusOr<PrivateIntersectionSumClientMessage::ClientRoundOne> ReEncryptSet(
82       const PrivateIntersectionSumServerMessage::ServerRoundOne&
83           server_message);
84 
85   // After the server computes the intersection-sum, it will send it back to
86   // this party for decryption, together with the intersection_size. This party
87   // will decrypt and output the intersection sum and intersection size.
88   StatusOr<std::pair<int64_t, BigNum>> DecryptSum(
89       const PrivateIntersectionSumServerMessage::ServerRoundTwo&
90           server_message);
91 
92   Context* ctx_;  // not owned
93   std::vector<std::string> elements_;
94   std::vector<BigNum> values_;
95 
96   // The Paillier private key
97   BigNum p_, q_;
98 
99   // These values will hold the intersection sum and size when the protocol has
100   // been completed.
101   int64_t intersection_size_ = 0;
102   BigNum intersection_sum_;
103 
104   std::unique_ptr<ECCommutativeCipher> ec_cipher_;
105   std::unique_ptr<PrivatePaillier> private_paillier_;
106 
107   bool protocol_finished_ = false;
108 };
109 
110 }  // namespace private_join_and_compute
111 
112 #endif  // PRIVATE_JOIN_AND_COMPUTE_PRIVATE_INTERSECTION_SUM_CLIENT_IMPL_H_
113