xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_protocol_impl.h (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
18 #define FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
19 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/server/experiments_interface.h"
#include "fcp/secagg/server/secagg_scheduler.h"
#include "fcp/secagg/server/secagg_server_enums.pb.h"
#include "fcp/secagg/server/secagg_server_metrics_listener.h"
#include "fcp/secagg/server/secret_sharing_graph.h"
#include "fcp/secagg/server/send_to_clients_interface.h"
#include "fcp/secagg/shared/aes_prng_factory.h"
#include "fcp/secagg/shared/compute_session_id.h"
#include "fcp/secagg/shared/ecdh_key_agreement.h"
#include "fcp/secagg/shared/ecdh_keys.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/secagg/shared/secagg_vector.h"
#include "fcp/secagg/shared/shamir_secret_sharing.h"
41 
42 namespace fcp {
43 namespace secagg {
44 
45 // Interface that describes internal implementation of SecAgg protocol.
46 //
47 // The general design is the following
48 //
49 // +--------------+    +-------------------+    +--------------------------+
50 // | SecAggServer |--->| SecAggServerState |--->| SecAggServerProtocolImpl |
51 // +--------------+    +-------------------+    +--------------------------+
52 //                                ^                          ^
53 //                               /-\                        /-\
54 //                                |                          |
55 //                     +-------------------+    +------------------------+
56 //                     | Specific State    |    | Specific protocol impl |
57 //                     +-------------------+    +------------------------+
58 //
59 // Specific states implement logic specific to each logic SecAgg state, such as
60 // R0AdvertiseKeys or PrngRunning, while specific protocol implementation is
61 // shared between all states and is responsible for encapsulating the data
62 // for the protocol and providing methods for manipulating the data.
63 //
64 
65 class SecAggServerProtocolImpl {
66  public:
67   explicit SecAggServerProtocolImpl(
68       std::unique_ptr<SecretSharingGraph> graph,
69       int minimum_number_of_clients_to_proceed,
70       std::unique_ptr<SecAggServerMetricsListener> metrics,
71       std::unique_ptr<AesPrngFactory> prng_factory,
72       SendToClientsInterface* sender,
73       std::unique_ptr<SecAggScheduler> scheduler,
74       std::vector<ClientStatus> client_statuses,
75       std::unique_ptr<ExperimentsInterface> experiments = nullptr);
76   virtual ~SecAggServerProtocolImpl() = default;
77 
78   SecAggServerProtocolImpl(const SecAggServerProtocolImpl& other) = delete;
79   SecAggServerProtocolImpl& operator=(const SecAggServerProtocolImpl& other) =
80       delete;
81 
82   // Returns server variant for this protocol implementation.
83   virtual ServerVariant server_variant() const = 0;
84 
85   // Returns the graph that represents the cohort of clients.
secret_sharing_graph()86   inline const SecretSharingGraph* secret_sharing_graph() const {
87     return secret_sharing_graph_.get();
88   }
89 
90   // Returns the minimum threshold number of clients that need to send valid
91   // responses in order for the protocol to proceed from one round to the next.
minimum_number_of_clients_to_proceed()92   inline int minimum_number_of_clients_to_proceed() const {
93     return minimum_number_of_clients_to_proceed_;
94   }
95 
96   // Returns the callback interface for recording metrics.
metrics()97   inline SecAggServerMetricsListener* metrics() const { return metrics_.get(); }
98 
99   // Returns a reference to an instance of a subclass of AesPrngFactory.
prng_factory()100   inline AesPrngFactory* prng_factory() const { return prng_factory_.get(); }
101 
102   // Returns the callback interface for sending protocol buffer messages to the
103   // client.
sender()104   inline SendToClientsInterface* sender() const { return sender_; }
105 
106   // Returns the scheduler for scheduling parallel computation tasks and
107   // callbacks.
scheduler()108   inline SecAggScheduler* scheduler() const { return scheduler_.get(); }
109 
110   // Returns the experiments
experiments()111   inline ExperimentsInterface* experiments() const {
112     return experiments_.get();
113   }
114 
115   // Getting or setting the protocol result.
116   //
117   // TODO(team): SetResult should not be needed (except for testing) once
118   // PRNG computation is moved into the protocol implementation.
119   void SetResult(std::unique_ptr<SecAggVectorMap> result);
120   std::unique_ptr<SecAggVectorMap> TakeResult();
121 
122   // Gets the client status.
client_status(uint32_t client_id)123   inline const ClientStatus& client_status(uint32_t client_id) const {
124     return client_statuses_.at(client_id);
125   }
126 
127   // Sets the client status.
set_client_status(uint32_t client_id,ClientStatus status)128   inline void set_client_status(uint32_t client_id, ClientStatus status) {
129     client_statuses_[client_id] = status;
130   }
131 
132   // Gets the number of clients that the protocol starts with.
total_number_of_clients()133   inline size_t total_number_of_clients() const {
134     return total_number_of_clients_;
135   }
136 
137   // Returns the number of neighbors of each client.
number_of_neighbors()138   inline const int number_of_neighbors() const {
139     return secret_sharing_graph()->GetDegree();
140   }
141 
142   // Returns the minimum number of neighbors of a client that must not drop-out
143   // for that client's contribution to be included in the sum. This corresponds
144   // to the threshold in the shamir secret sharing of self and pairwise masks.
minimum_surviving_neighbors_for_reconstruction()145   inline const int minimum_surviving_neighbors_for_reconstruction() const {
146     return secret_sharing_graph()->GetThreshold();
147   }
148 
149   // Returns client_id's ith neighbor.
150   // This function assumes that 0 <= i < number_of_neighbors() and will throw a
151   // runtime error if that's not the case
GetNeighbor(int client_id,int i)152   inline const int GetNeighbor(int client_id, int i) const {
153     return secret_sharing_graph()->GetNeighbor(client_id, i);
154   }
155 
156   // Returns the index of client_id_2 in the list of neighbors of client_id_1,
157   // if present
GetNeighborIndex(int client_id_1,int client_id_2)158   inline const std::optional<int> GetNeighborIndex(int client_id_1,
159                                                    int client_id_2) const {
160     return secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2);
161   }
162 
163   // Returns the index of client_id_2 in the list of neighbors of client_id_1
164   // This function assumes that client_id_1 and client_id_2 are neighbors, and
165   // will throw a runtime error if that's not the case
GetNeighborIndexOrDie(int client_id_1,int client_id_2)166   inline const int GetNeighborIndexOrDie(int client_id_1,
167                                          int client_id_2) const {
168     auto index =
169         secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2);
170     FCP_CHECK(index.has_value());
171     return index.value();
172   }
173 
174   // Returns true if clients client_id_1 and client_id_1 are neighbors, else
175   // false.
AreNeighbors(int client_id_1,int client_id_2)176   inline const bool AreNeighbors(int client_id_1, int client_id_2) const {
177     return secret_sharing_graph()->AreNeighbors(client_id_1, client_id_2);
178   }
179 
180   // Returns true if client_id_1 is an outgoing neighbor of client_id_2, else
181   // false.
IsOutgoingNeighbor(int client_id_1,int client_id_2)182   inline const bool IsOutgoingNeighbor(int client_id_1, int client_id_2) const {
183     return secret_sharing_graph()->IsOutgoingNeighbor(client_id_1, client_id_2);
184   }
185 
SetPairwisePublicKeys(uint32_t client_id,const EcdhPublicKey & pairwise_key)186   inline void SetPairwisePublicKeys(uint32_t client_id,
187                                     const EcdhPublicKey& pairwise_key) {
188     pairwise_public_keys_[client_id] = pairwise_key;
189   }
190 
pairwise_public_keys(uint32_t client_id)191   inline const EcdhPublicKey& pairwise_public_keys(uint32_t client_id) const {
192     return pairwise_public_keys_[client_id];
193   }
194 
session_id()195   inline const SessionId& session_id() const {
196     FCP_CHECK(session_id_ != nullptr);
197     return *session_id_;
198   }
199 
set_session_id(std::unique_ptr<SessionId> session_id)200   void set_session_id(std::unique_ptr<SessionId> session_id) {
201     FCP_CHECK(session_id != nullptr);
202     session_id_ = std::move(session_id);
203   }
204 
205   // TODO(team): Review whether getters and setters below are needed.
206   // Most of these fields are needed only for testing.
207 
set_pairwise_shamir_share_table(std::unique_ptr<absl::flat_hash_map<uint32_t,std::vector<ShamirShare>>> pairwise_shamir_share_table)208   void set_pairwise_shamir_share_table(
209       std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
210           pairwise_shamir_share_table) {
211     pairwise_shamir_share_table_ = std::move(pairwise_shamir_share_table);
212   }
213 
set_self_shamir_share_table(std::unique_ptr<absl::flat_hash_map<uint32_t,std::vector<ShamirShare>>> self_shamir_share_table)214   void set_self_shamir_share_table(
215       std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
216           self_shamir_share_table) {
217     self_shamir_share_table_ = std::move(self_shamir_share_table);
218   }
219 
220   // ---------------------------------------------------------------------------
221   // Round 0 methods
222   // ---------------------------------------------------------------------------
223 
224   // Sets the public key pairs for a client.
225   Status HandleAdvertiseKeys(uint32_t client_id,
226                              const AdvertiseKeys& advertise_keys);
227 
228   // Erases public key pairs for a client.
229   void ErasePublicKeysForClient(uint32_t client_id);
230 
231   // Compute session ID based on public key pairs advertised by clients.
232   void ComputeSessionId();
233 
234   // This method allows a protocol implementation to populate fields that are
235   // common to the ShareKeysRequest sent to all clients.
236   virtual Status InitializeShareKeysRequest(
237       ShareKeysRequest* request) const = 0;
238 
239   // Prepares ShareKeysRequest message to send to the client.
240   // This method will update fields in the request as needed, but will not clear
241   // any fields that are not specific to the share keys request for the specific
242   // client.  The caller can therefore set up a single ShareKeysRequest object,
243   // populate fields that will be common to all clients, and repeatedly call
244   // this method to set the client-specific fields before serializing the
245   // message and sending it.
246   void PrepareShareKeysRequestForClient(uint32_t client_id,
247                                         ShareKeysRequest* request) const;
248 
249   // Clears all pairs of public keys.
250   void ClearPairsOfPublicKeys();
251 
252   // ---------------------------------------------------------------------------
253   // Round 1 methods
254   // ---------------------------------------------------------------------------
255 
256   // Sets the encrypted shares received from a client.
257   Status HandleShareKeysResponse(uint32_t client_id,
258                                  const ShareKeysResponse& share_keys_response);
259 
260   // Erases the encrypted shares for a client.
261   void EraseShareKeysForClient(uint32_t client_id);
262 
263   // Prepares MaskedInputCollectionRequest message to send to the client.
264   // This method will update fields in the request as needed, but will not clear
265   // any fields that are not specific to the share keys request for the specific
266   // client.  The caller can therefore set up a single ShareKeysRequest object,
267   // populate fields that will be common to all clients, and repeatedly call
268   // this method to set the client-specific fields before serializing the
269   // message and sending it.
270   void PrepareMaskedInputCollectionRequestForClient(
271       uint32_t client_id, MaskedInputCollectionRequest* request) const;
272 
273   // Clears all encrypted shares.
274   void ClearShareKeys();
275 
276   // ---------------------------------------------------------------------------
277   // Round 2 methods
278   // ---------------------------------------------------------------------------
279 
280   // Sets up the sum of encrypted vectors received by the clients in R1.  This
281   // must be called before any other R2 methods are called.
282   virtual std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
283   SetupMaskedInputCollection() = 0;
284 
285   // Finalizes the async aggregation of R2 messages before moving to R3.
286   virtual void FinalizeMaskedInputCollection() = 0;
287 
288   // Check that an encrypted vector received by the user is valid, and add it to
289   // the sum of encrypted vectors.
290   virtual Status HandleMaskedInputCollectionResponse(
291       std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) = 0;
292 
293   // ---------------------------------------------------------------------------
294   // Round 3 methods
295   // ---------------------------------------------------------------------------
296 
297   // This must be called in the beginning of round 3 to setup Shamir shares
298   // tables based on client states at the beginning of the round.
299   void SetUpShamirSharesTables();
300 
301   // Populates Shamir shares tables with the data from UnmaskingResponse.
302   // Returning an error status means that the unmasking response was invalid.
303   Status HandleUnmaskingResponse(uint32_t client_id,
304                                  const UnmaskingResponse& unmasking_response);
305 
306   // ---------------------------------------------------------------------------
307   // PRNG computation methods
308   // ---------------------------------------------------------------------------
309 
310   // Result of performing Shamir secret sharing keys reconstruction.
311   struct ShamirReconstructionResult {
312     absl::flat_hash_map<uint32_t, EcdhKeyAgreement>
313         aborted_client_key_agreements;
314     absl::node_hash_map<uint32_t, AesKey> self_keys;
315   };
316 
317   // Performs reconstruction secret sharing keys reconstruction step of
318   // the PRNG stage of the protocol.
319   StatusOr<ShamirReconstructionResult> HandleShamirReconstruction();
320 
321   struct PrngWorkItems {
322     std::vector<AesKey> prng_keys_to_add;
323     std::vector<AesKey> prng_keys_to_subtract;
324   };
325 
326   // Initializes PRNG work items.
327   StatusOr<PrngWorkItems> InitializePrng(
328       const ShamirReconstructionResult& shamir_reconstruction_result) const;
329 
330   // Tells the PRNG stage of the protocol to start running asynchronously by
331   // executing PRNG work items.
332   // The returned cancellation token can be used to abort the asynchronous
333   // execution.
334   virtual CancellationToken StartPrng(
335       const PrngWorkItems& work_items,
336       std::function<void(Status)> done_callback) = 0;
337 
338  private:
339   std::unique_ptr<SecretSharingGraph> secret_sharing_graph_;
340   int minimum_number_of_clients_to_proceed_;
341 
342   std::vector<InputVectorSpecification> input_vector_specs_;
343   std::unique_ptr<SecAggServerMetricsListener> metrics_;
344   std::unique_ptr<AesPrngFactory> prng_factory_;
345   SendToClientsInterface* sender_;
346   std::unique_ptr<SecAggScheduler> scheduler_;
347 
348   std::unique_ptr<SecAggVectorMap> result_;
349 
350   size_t total_number_of_clients_;
351   std::vector<ClientStatus> client_statuses_;
352   std::unique_ptr<ExperimentsInterface> experiments_;
353 
354   // This vector collects the public keys sent by the clients that will be used
355   // for running the PRNG later on.
356   std::vector<EcdhPublicKey> pairwise_public_keys_;
357 
358   // This vector collects all pairs of public keys sent by the clients, so they
359   // can be forwarded at the end of Advertise Keys round.
360   std::vector<PairOfPublicKeys> pairs_of_public_keys_;
361 
362   std::unique_ptr<SessionId> session_id_;
363 
364   // Track the encrypted shares received from clients in preparation for sending
365   // them. encrypted_shares_table_[i][j] is an encryption of the pair of shares
366   // to be sent to client i, received from client j.
367   std::vector<std::vector<std::string>> encrypted_shares_;
368 
369   // Shamir shares tables.
370   // These store shares that have been collected from clients, and will be built
371   // up over the course of round 3. For both tables, the map key represents
372   // the client whose key these are shares of; the index in the vector
373   // represents the client who provided that key share.
374   std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
375       pairwise_shamir_share_table_;
376   std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>
377       self_shamir_share_table_;
378 };
379 
380 }  // namespace secagg
381 }  // namespace fcp
382 
383 #endif  // FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
384