/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_
#define FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/server/experiments_interface.h"
#include "fcp/secagg/server/secagg_scheduler.h"
#include "fcp/secagg/server/secagg_server_enums.pb.h"
#include "fcp/secagg/server/secagg_server_metrics_listener.h"
#include "fcp/secagg/server/secret_sharing_graph.h"
#include "fcp/secagg/server/send_to_clients_interface.h"
#include "fcp/secagg/shared/aes_prng_factory.h"
#include "fcp/secagg/shared/compute_session_id.h"
#include "fcp/secagg/shared/ecdh_key_agreement.h"
#include "fcp/secagg/shared/ecdh_keys.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/secagg/shared/secagg_vector.h"
#include "fcp/secagg/shared/shamir_secret_sharing.h"

namespace fcp {
namespace secagg {

// Interface that describes internal implementation of SecAgg protocol.
46 // 47 // The general design is the following 48 // 49 // +--------------+ +-------------------+ +--------------------------+ 50 // | SecAggServer |--->| SecAggServerState |--->| SecAggServerProtocolImpl | 51 // +--------------+ +-------------------+ +--------------------------+ 52 // ^ ^ 53 // /-\ /-\ 54 // | | 55 // +-------------------+ +------------------------+ 56 // | Specific State | | Specific protocol impl | 57 // +-------------------+ +------------------------+ 58 // 59 // Specific states implement logic specific to each logic SecAgg state, such as 60 // R0AdvertiseKeys or PrngRunning, while specific protocol implementation is 61 // shared between all states and is responsible for encapsulating the data 62 // for the protocol and providing methods for manipulating the data. 63 // 64 65 class SecAggServerProtocolImpl { 66 public: 67 explicit SecAggServerProtocolImpl( 68 std::unique_ptr<SecretSharingGraph> graph, 69 int minimum_number_of_clients_to_proceed, 70 std::unique_ptr<SecAggServerMetricsListener> metrics, 71 std::unique_ptr<AesPrngFactory> prng_factory, 72 SendToClientsInterface* sender, 73 std::unique_ptr<SecAggScheduler> scheduler, 74 std::vector<ClientStatus> client_statuses, 75 std::unique_ptr<ExperimentsInterface> experiments = nullptr); 76 virtual ~SecAggServerProtocolImpl() = default; 77 78 SecAggServerProtocolImpl(const SecAggServerProtocolImpl& other) = delete; 79 SecAggServerProtocolImpl& operator=(const SecAggServerProtocolImpl& other) = 80 delete; 81 82 // Returns server variant for this protocol implementation. 83 virtual ServerVariant server_variant() const = 0; 84 85 // Returns the graph that represents the cohort of clients. secret_sharing_graph()86 inline const SecretSharingGraph* secret_sharing_graph() const { 87 return secret_sharing_graph_.get(); 88 } 89 90 // Returns the minimum threshold number of clients that need to send valid 91 // responses in order for the protocol to proceed from one round to the next. 
minimum_number_of_clients_to_proceed()92 inline int minimum_number_of_clients_to_proceed() const { 93 return minimum_number_of_clients_to_proceed_; 94 } 95 96 // Returns the callback interface for recording metrics. metrics()97 inline SecAggServerMetricsListener* metrics() const { return metrics_.get(); } 98 99 // Returns a reference to an instance of a subclass of AesPrngFactory. prng_factory()100 inline AesPrngFactory* prng_factory() const { return prng_factory_.get(); } 101 102 // Returns the callback interface for sending protocol buffer messages to the 103 // client. sender()104 inline SendToClientsInterface* sender() const { return sender_; } 105 106 // Returns the scheduler for scheduling parallel computation tasks and 107 // callbacks. scheduler()108 inline SecAggScheduler* scheduler() const { return scheduler_.get(); } 109 110 // Returns the experiments experiments()111 inline ExperimentsInterface* experiments() const { 112 return experiments_.get(); 113 } 114 115 // Getting or setting the protocol result. 116 // 117 // TODO(team): SetResult should not be needed (except for testing) once 118 // PRNG computation is moved into the protocol implementation. 119 void SetResult(std::unique_ptr<SecAggVectorMap> result); 120 std::unique_ptr<SecAggVectorMap> TakeResult(); 121 122 // Gets the client status. client_status(uint32_t client_id)123 inline const ClientStatus& client_status(uint32_t client_id) const { 124 return client_statuses_.at(client_id); 125 } 126 127 // Sets the client status. set_client_status(uint32_t client_id,ClientStatus status)128 inline void set_client_status(uint32_t client_id, ClientStatus status) { 129 client_statuses_[client_id] = status; 130 } 131 132 // Gets the number of clients that the protocol starts with. total_number_of_clients()133 inline size_t total_number_of_clients() const { 134 return total_number_of_clients_; 135 } 136 137 // Returns the number of neighbors of each client. 
number_of_neighbors()138 inline const int number_of_neighbors() const { 139 return secret_sharing_graph()->GetDegree(); 140 } 141 142 // Returns the minimum number of neighbors of a client that must not drop-out 143 // for that client's contribution to be included in the sum. This corresponds 144 // to the threshold in the shamir secret sharing of self and pairwise masks. minimum_surviving_neighbors_for_reconstruction()145 inline const int minimum_surviving_neighbors_for_reconstruction() const { 146 return secret_sharing_graph()->GetThreshold(); 147 } 148 149 // Returns client_id's ith neighbor. 150 // This function assumes that 0 <= i < number_of_neighbors() and will throw a 151 // runtime error if that's not the case GetNeighbor(int client_id,int i)152 inline const int GetNeighbor(int client_id, int i) const { 153 return secret_sharing_graph()->GetNeighbor(client_id, i); 154 } 155 156 // Returns the index of client_id_2 in the list of neighbors of client_id_1, 157 // if present GetNeighborIndex(int client_id_1,int client_id_2)158 inline const std::optional<int> GetNeighborIndex(int client_id_1, 159 int client_id_2) const { 160 return secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2); 161 } 162 163 // Returns the index of client_id_2 in the list of neighbors of client_id_1 164 // This function assumes that client_id_1 and client_id_2 are neighbors, and 165 // will throw a runtime error if that's not the case GetNeighborIndexOrDie(int client_id_1,int client_id_2)166 inline const int GetNeighborIndexOrDie(int client_id_1, 167 int client_id_2) const { 168 auto index = 169 secret_sharing_graph()->GetNeighborIndex(client_id_1, client_id_2); 170 FCP_CHECK(index.has_value()); 171 return index.value(); 172 } 173 174 // Returns true if clients client_id_1 and client_id_1 are neighbors, else 175 // false. 
AreNeighbors(int client_id_1,int client_id_2)176 inline const bool AreNeighbors(int client_id_1, int client_id_2) const { 177 return secret_sharing_graph()->AreNeighbors(client_id_1, client_id_2); 178 } 179 180 // Returns true if client_id_1 is an outgoing neighbor of client_id_2, else 181 // false. IsOutgoingNeighbor(int client_id_1,int client_id_2)182 inline const bool IsOutgoingNeighbor(int client_id_1, int client_id_2) const { 183 return secret_sharing_graph()->IsOutgoingNeighbor(client_id_1, client_id_2); 184 } 185 SetPairwisePublicKeys(uint32_t client_id,const EcdhPublicKey & pairwise_key)186 inline void SetPairwisePublicKeys(uint32_t client_id, 187 const EcdhPublicKey& pairwise_key) { 188 pairwise_public_keys_[client_id] = pairwise_key; 189 } 190 pairwise_public_keys(uint32_t client_id)191 inline const EcdhPublicKey& pairwise_public_keys(uint32_t client_id) const { 192 return pairwise_public_keys_[client_id]; 193 } 194 session_id()195 inline const SessionId& session_id() const { 196 FCP_CHECK(session_id_ != nullptr); 197 return *session_id_; 198 } 199 set_session_id(std::unique_ptr<SessionId> session_id)200 void set_session_id(std::unique_ptr<SessionId> session_id) { 201 FCP_CHECK(session_id != nullptr); 202 session_id_ = std::move(session_id); 203 } 204 205 // TODO(team): Review whether getters and setters below are needed. 206 // Most of these fields are needed only for testing. 
207 set_pairwise_shamir_share_table(std::unique_ptr<absl::flat_hash_map<uint32_t,std::vector<ShamirShare>>> pairwise_shamir_share_table)208 void set_pairwise_shamir_share_table( 209 std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>> 210 pairwise_shamir_share_table) { 211 pairwise_shamir_share_table_ = std::move(pairwise_shamir_share_table); 212 } 213 set_self_shamir_share_table(std::unique_ptr<absl::flat_hash_map<uint32_t,std::vector<ShamirShare>>> self_shamir_share_table)214 void set_self_shamir_share_table( 215 std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>> 216 self_shamir_share_table) { 217 self_shamir_share_table_ = std::move(self_shamir_share_table); 218 } 219 220 // --------------------------------------------------------------------------- 221 // Round 0 methods 222 // --------------------------------------------------------------------------- 223 224 // Sets the public key pairs for a client. 225 Status HandleAdvertiseKeys(uint32_t client_id, 226 const AdvertiseKeys& advertise_keys); 227 228 // Erases public key pairs for a client. 229 void ErasePublicKeysForClient(uint32_t client_id); 230 231 // Compute session ID based on public key pairs advertised by clients. 232 void ComputeSessionId(); 233 234 // This method allows a protocol implementation to populate fields that are 235 // common to the ShareKeysRequest sent to all clients. 236 virtual Status InitializeShareKeysRequest( 237 ShareKeysRequest* request) const = 0; 238 239 // Prepares ShareKeysRequest message to send to the client. 240 // This method will update fields in the request as needed, but will not clear 241 // any fields that are not specific to the share keys request for the specific 242 // client. The caller can therefore set up a single ShareKeysRequest object, 243 // populate fields that will be common to all clients, and repeatedly call 244 // this method to set the client-specific fields before serializing the 245 // message and sending it. 
246 void PrepareShareKeysRequestForClient(uint32_t client_id, 247 ShareKeysRequest* request) const; 248 249 // Clears all pairs of public keys. 250 void ClearPairsOfPublicKeys(); 251 252 // --------------------------------------------------------------------------- 253 // Round 1 methods 254 // --------------------------------------------------------------------------- 255 256 // Sets the encrypted shares received from a client. 257 Status HandleShareKeysResponse(uint32_t client_id, 258 const ShareKeysResponse& share_keys_response); 259 260 // Erases the encrypted shares for a client. 261 void EraseShareKeysForClient(uint32_t client_id); 262 263 // Prepares MaskedInputCollectionRequest message to send to the client. 264 // This method will update fields in the request as needed, but will not clear 265 // any fields that are not specific to the share keys request for the specific 266 // client. The caller can therefore set up a single ShareKeysRequest object, 267 // populate fields that will be common to all clients, and repeatedly call 268 // this method to set the client-specific fields before serializing the 269 // message and sending it. 270 void PrepareMaskedInputCollectionRequestForClient( 271 uint32_t client_id, MaskedInputCollectionRequest* request) const; 272 273 // Clears all encrypted shares. 274 void ClearShareKeys(); 275 276 // --------------------------------------------------------------------------- 277 // Round 2 methods 278 // --------------------------------------------------------------------------- 279 280 // Sets up the sum of encrypted vectors received by the clients in R1. This 281 // must be called before any other R2 methods are called. 282 virtual std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>> 283 SetupMaskedInputCollection() = 0; 284 285 // Finalizes the async aggregation of R2 messages before moving to R3. 
286 virtual void FinalizeMaskedInputCollection() = 0; 287 288 // Check that an encrypted vector received by the user is valid, and add it to 289 // the sum of encrypted vectors. 290 virtual Status HandleMaskedInputCollectionResponse( 291 std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) = 0; 292 293 // --------------------------------------------------------------------------- 294 // Round 3 methods 295 // --------------------------------------------------------------------------- 296 297 // This must be called in the beginning of round 3 to setup Shamir shares 298 // tables based on client states at the beginning of the round. 299 void SetUpShamirSharesTables(); 300 301 // Populates Shamir shares tables with the data from UnmaskingResponse. 302 // Returning an error status means that the unmasking response was invalid. 303 Status HandleUnmaskingResponse(uint32_t client_id, 304 const UnmaskingResponse& unmasking_response); 305 306 // --------------------------------------------------------------------------- 307 // PRNG computation methods 308 // --------------------------------------------------------------------------- 309 310 // Result of performing Shamir secret sharing keys reconstruction. 311 struct ShamirReconstructionResult { 312 absl::flat_hash_map<uint32_t, EcdhKeyAgreement> 313 aborted_client_key_agreements; 314 absl::node_hash_map<uint32_t, AesKey> self_keys; 315 }; 316 317 // Performs reconstruction secret sharing keys reconstruction step of 318 // the PRNG stage of the protocol. 319 StatusOr<ShamirReconstructionResult> HandleShamirReconstruction(); 320 321 struct PrngWorkItems { 322 std::vector<AesKey> prng_keys_to_add; 323 std::vector<AesKey> prng_keys_to_subtract; 324 }; 325 326 // Initializes PRNG work items. 
327 StatusOr<PrngWorkItems> InitializePrng( 328 const ShamirReconstructionResult& shamir_reconstruction_result) const; 329 330 // Tells the PRNG stage of the protocol to start running asynchronously by 331 // executing PRNG work items. 332 // The returned cancellation token can be used to abort the asynchronous 333 // execution. 334 virtual CancellationToken StartPrng( 335 const PrngWorkItems& work_items, 336 std::function<void(Status)> done_callback) = 0; 337 338 private: 339 std::unique_ptr<SecretSharingGraph> secret_sharing_graph_; 340 int minimum_number_of_clients_to_proceed_; 341 342 std::vector<InputVectorSpecification> input_vector_specs_; 343 std::unique_ptr<SecAggServerMetricsListener> metrics_; 344 std::unique_ptr<AesPrngFactory> prng_factory_; 345 SendToClientsInterface* sender_; 346 std::unique_ptr<SecAggScheduler> scheduler_; 347 348 std::unique_ptr<SecAggVectorMap> result_; 349 350 size_t total_number_of_clients_; 351 std::vector<ClientStatus> client_statuses_; 352 std::unique_ptr<ExperimentsInterface> experiments_; 353 354 // This vector collects the public keys sent by the clients that will be used 355 // for running the PRNG later on. 356 std::vector<EcdhPublicKey> pairwise_public_keys_; 357 358 // This vector collects all pairs of public keys sent by the clients, so they 359 // can be forwarded at the end of Advertise Keys round. 360 std::vector<PairOfPublicKeys> pairs_of_public_keys_; 361 362 std::unique_ptr<SessionId> session_id_; 363 364 // Track the encrypted shares received from clients in preparation for sending 365 // them. encrypted_shares_table_[i][j] is an encryption of the pair of shares 366 // to be sent to client i, received from client j. 367 std::vector<std::vector<std::string>> encrypted_shares_; 368 369 // Shamir shares tables. 370 // These store shares that have been collected from clients, and will be built 371 // up over the course of round 3. 
For both tables, the map key represents 372 // the client whose key these are shares of; the index in the vector 373 // represents the client who provided that key share. 374 std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>> 375 pairwise_shamir_share_table_; 376 std::unique_ptr<absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>> 377 self_shamir_share_table_; 378 }; 379 380 } // namespace secagg 381 } // namespace fcp 382 383 #endif // FCP_SECAGG_SERVER_SECAGG_SERVER_PROTOCOL_IMPL_H_ 384