xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_protocol_impl.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_protocol_impl.h"
18 
19 #include <string>
20 #include <utility>
21 
22 #include "absl/container/node_hash_map.h"
23 #include "absl/status/status.h"
24 #include "absl/time/time.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/tracing_schema.h"
27 #include "fcp/secagg/shared/compute_session_id.h"
28 #include "fcp/tracing/tracing_span.h"
29 
30 namespace {
31 
32 // Defines an experiments object with no experiments enabled
33 class EmptyExperiment : public fcp::secagg::ExperimentsInterface {
34  public:
IsEnabled(absl::string_view experiment_name)35   bool IsEnabled(absl::string_view experiment_name) override { return false; }
36 };
37 
38 }  // namespace
39 
40 namespace fcp {
41 namespace secagg {
42 
// Creates the server-side protocol state for one SecAgg session.
//
// Ownership: `graph`, `metrics`, `prng_factory`, `scheduler` and
// `experiments` are moved into members. `sender` is stored as a raw pointer
// without taking ownership (NOTE(review): presumably it must outlive this
// object — confirm with callers). When `experiments` is null, an
// EmptyExperiment (which reports every experiment as disabled) is used.
//
// The per-client tables are sized up front:
//   - pairwise_public_keys_ and pairs_of_public_keys_ get
//     total_number_of_clients() default entries;
//   - encrypted_shares_ gets total_number_of_clients() rows, each holding
//     number_of_neighbors() empty strings.
//
// NOTE(review): several initializers below read other members or depend on
// a parameter that is moved elsewhere in the same list:
//   - total_number_of_clients_ reads `client_statuses.size()`, and the same
//     parameter is moved into client_statuses_;
//   - the tables call total_number_of_clients() and number_of_neighbors()
//     (presumably backed by total_number_of_clients_ and
//     secret_sharing_graph_).
// C++ initializes members in header declaration order, not in the order
// written here. This is therefore only correct if the header declares the
// members in a compatible order — confirm against
// secagg_server_protocol_impl.h before reordering anything.
SecAggServerProtocolImpl::SecAggServerProtocolImpl(
    std::unique_ptr<SecretSharingGraph> graph,
    int minimum_number_of_clients_to_proceed,
    std::unique_ptr<SecAggServerMetricsListener> metrics,
    std::unique_ptr<AesPrngFactory> prng_factory,
    SendToClientsInterface* sender, std::unique_ptr<SecAggScheduler> scheduler,
    std::vector<ClientStatus> client_statuses,
    std::unique_ptr<ExperimentsInterface> experiments)
    : secret_sharing_graph_(std::move(graph)),
      minimum_number_of_clients_to_proceed_(
          minimum_number_of_clients_to_proceed),
      metrics_(std::move(metrics)),
      prng_factory_(std::move(prng_factory)),
      // Non-owning; see NOTE(review) above.
      sender_(sender),
      scheduler_(std::move(scheduler)),
      // Must read the size before `client_statuses` is moved from.
      total_number_of_clients_(client_statuses.size()),
      client_statuses_(std::move(client_statuses)),
      // A null `experiments` falls back to "no experiments enabled".
      experiments_(experiments ? std::move(experiments)
                               : std::unique_ptr<ExperimentsInterface>(
                                     new EmptyExperiment())),
      pairwise_public_keys_(total_number_of_clients()),
      pairs_of_public_keys_(total_number_of_clients()),
      // One row per client, one slot per neighbor of that client.
      encrypted_shares_(total_number_of_clients(),
                        std::vector<std::string>(number_of_neighbors())) {}
67 
SetResult(std::unique_ptr<SecAggVectorMap> result)68 void SecAggServerProtocolImpl::SetResult(
69     std::unique_ptr<SecAggVectorMap> result) {
70   FCP_CHECK(!result_) << "Result can't be set twice";
71   result_ = std::move(result);
72 }
73 
TakeResult()74 std::unique_ptr<SecAggVectorMap> SecAggServerProtocolImpl::TakeResult() {
75   return std::move(result_);
76 }
77 
78 // -----------------------------------------------------------------------------
79 // Round 0 methods
80 // -----------------------------------------------------------------------------
81 
HandleAdvertiseKeys(uint32_t client_id,const AdvertiseKeys & advertise_keys)82 Status SecAggServerProtocolImpl::HandleAdvertiseKeys(
83     uint32_t client_id, const AdvertiseKeys& advertise_keys) {
84   const auto& pair_of_public_keys = advertise_keys.pair_of_public_keys();
85   if ((pair_of_public_keys.enc_pk().size() != EcdhPublicKey::kSize &&
86        (pair_of_public_keys.enc_pk().size() <
87             EcdhPublicKey::kUncompressedSize ||
88         pair_of_public_keys.noise_pk().size() <
89             EcdhPublicKey::kUncompressedSize)) ||
90       pair_of_public_keys.enc_pk().size() !=
91           pair_of_public_keys.noise_pk().size()) {
92     return ::absl::InvalidArgumentError(
93         "A public key sent by the client was not the correct size.");
94   }
95 
96   if (pair_of_public_keys.noise_pk().size() == EcdhPublicKey::kSize) {
97     pairwise_public_keys_[client_id] =
98         EcdhPublicKey(reinterpret_cast<const uint8_t*>(
99             pair_of_public_keys.noise_pk().c_str()));
100   } else {
101     // Strip off the header, if any, and use the uncompressed ECDH key.
102     size_t key_size_with_header = pair_of_public_keys.noise_pk().size();
103     pairwise_public_keys_[client_id] = EcdhPublicKey(
104         reinterpret_cast<const uint8_t*>(
105             pair_of_public_keys.noise_pk()
106                 .substr(key_size_with_header - EcdhPublicKey::kUncompressedSize)
107                 .c_str()),
108         EcdhPublicKey::kUncompressed);
109   }
110 
111   pairs_of_public_keys_[client_id] = pair_of_public_keys;
112   return ::absl::OkStatus();
113 }
114 
ErasePublicKeysForClient(uint32_t client_id)115 void SecAggServerProtocolImpl::ErasePublicKeysForClient(uint32_t client_id) {
116   pairwise_public_keys_[client_id] = EcdhPublicKey();
117   pairs_of_public_keys_[client_id] = PairOfPublicKeys();
118 }
119 
ComputeSessionId()120 void SecAggServerProtocolImpl::ComputeSessionId() {
121   // This message contains all keys, and is only built for the purpose
122   // of deriving the session key from it
123   ShareKeysRequest share_keys_request;
124   for (int i = 0; i < total_number_of_clients(); ++i) {
125     *(share_keys_request.add_pairs_of_public_keys()) = pairs_of_public_keys_[i];
126   }
127   set_session_id(std::make_unique<SessionId>(
128       fcp::secagg::ComputeSessionId(share_keys_request)));
129 }
130 
PrepareShareKeysRequestForClient(uint32_t client_id,ShareKeysRequest * request) const131 void SecAggServerProtocolImpl::PrepareShareKeysRequestForClient(
132     uint32_t client_id, ShareKeysRequest* request) const {
133   request->clear_pairs_of_public_keys();
134   for (int j = 0; j < secret_sharing_graph()->GetDegree(); ++j) {
135     *request->add_pairs_of_public_keys() =
136         pairs_of_public_keys_[secret_sharing_graph()->GetNeighbor(client_id,
137                                                                   j)];
138   }
139 }
140 
ClearPairsOfPublicKeys()141 void SecAggServerProtocolImpl::ClearPairsOfPublicKeys() {
142   pairs_of_public_keys_.clear();
143 }
144 
145 // -----------------------------------------------------------------------------
146 // Round 1 methods
147 // -----------------------------------------------------------------------------
148 
HandleShareKeysResponse(uint32_t client_id,const ShareKeysResponse & share_keys_response)149 Status SecAggServerProtocolImpl::HandleShareKeysResponse(
150     uint32_t client_id, const ShareKeysResponse& share_keys_response) {
151   // Verify that the message has the expected fields set before accepting it.
152   if (share_keys_response.encrypted_key_shares().size() !=
153       number_of_neighbors()) {
154     return ::absl::InvalidArgumentError(
155         "The ShareKeysResponse does not contain the expected number of "
156         "encrypted pairs of key shares.");
157   }
158 
159   for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
160     bool i_is_empty = share_keys_response.encrypted_key_shares(i).empty();
161     int neighbor_id = GetNeighbor(client_id, i);
162     bool i_should_be_empty = (neighbor_id == client_id) ||
163                              (client_status(neighbor_id) ==
164                               ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
165     if (i_is_empty && !i_should_be_empty) {
166       return ::absl::InvalidArgumentError(
167           "Client omitted a key share that was expected.");
168     }
169     if (i_should_be_empty && !i_is_empty) {
170       return ::absl::InvalidArgumentError(
171           "Client sent a key share that was not expected.");
172     }
173   }
174 
175   // Client sent a valid message.
176   for (int i = 0; i < number_of_neighbors(); ++i) {
177     int neighbor_id = GetNeighbor(client_id, i);
178     // neighbor_id and client_id are neighbors, and thus index_in_neighbors is
179     // in [0, number_neighbors()-1]
180     int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
181     encrypted_shares_[neighbor_id][index_in_neighbor] =
182         share_keys_response.encrypted_key_shares(i);
183   }
184 
185   return ::absl::OkStatus();
186 }
187 
EraseShareKeysForClient(uint32_t client_id)188 void SecAggServerProtocolImpl::EraseShareKeysForClient(uint32_t client_id) {
189   for (int i = 0; i < number_of_neighbors(); ++i) {
190     int neighbor_id = GetNeighbor(client_id, i);
191     int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
192     encrypted_shares_[neighbor_id][index_in_neighbor].clear();
193   }
194 }
195 
PrepareMaskedInputCollectionRequestForClient(uint32_t client_id,MaskedInputCollectionRequest * request) const196 void SecAggServerProtocolImpl::PrepareMaskedInputCollectionRequestForClient(
197     uint32_t client_id, MaskedInputCollectionRequest* request) const {
198   request->clear_encrypted_key_shares();
199   for (int j = 0; j < number_of_neighbors(); ++j) {
200     request->add_encrypted_key_shares(encrypted_shares_[client_id][j]);
201   }
202 }
203 
ClearShareKeys()204 void SecAggServerProtocolImpl::ClearShareKeys() { encrypted_shares_.clear(); }
205 
206 // -----------------------------------------------------------------------------
207 // Round 3 methods
208 // -----------------------------------------------------------------------------
209 
// This enum and the following function relate a client's status to whether
// its pairwise mask, its self mask, or neither will appear in the summed
// masked input. The value also determines which kind of key share the server
// expects each neighbor to return for that client in the unmasking round:
//   kPairwiseMask - a share of the client's noise (pairwise) secret key,
//   kSelfMask     - a share of the client's PRF (self-mask) secret key,
//   kNoMask       - no share at all.
enum class ClientMask { kPairwiseMask, kSelfMask, kNoMask };
214 
215 // Returns the type of mask the server expects to receive a share for, for a
216 // give client status.
ClientMaskType(const ClientStatus & client_status)217 inline ClientMask ClientMaskType(const ClientStatus& client_status) {
218   switch (client_status) {
219     case ClientStatus::SHARE_KEYS_RECEIVED:
220     case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
221       return ClientMask::kPairwiseMask;
222       break;
223     case ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED:
224     case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
225     case ClientStatus::UNMASKING_RESPONSE_RECEIVED:
226     case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
227       return ClientMask::kSelfMask;
228       break;
229     case ClientStatus::READY_TO_START:
230     case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
231     case ClientStatus::ADVERTISE_KEYS_RECEIVED:
232     case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
233     default:
234       return ClientMask::kNoMask;
235   }
236 }
237 
SetUpShamirSharesTables()238 void SecAggServerProtocolImpl::SetUpShamirSharesTables() {
239   pairwise_shamir_share_table_ = std::make_unique<
240       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
241   self_shamir_share_table_ = std::make_unique<
242       absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
243 
244   // Prepare the share tables with rows for clients we expect to have shares for
245   for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
246     auto mask_type = ClientMaskType(client_status(i));
247     if (mask_type == ClientMask::kPairwiseMask) {
248       pairwise_shamir_share_table_->emplace(i, number_of_neighbors());
249     } else if (mask_type == ClientMask::kSelfMask) {
250       self_shamir_share_table_->emplace(i, number_of_neighbors());
251     }
252   }
253 }
254 
HandleUnmaskingResponse(uint32_t client_id,const UnmaskingResponse & unmasking_response)255 Status SecAggServerProtocolImpl::HandleUnmaskingResponse(
256     uint32_t client_id, const UnmaskingResponse& unmasking_response) {
257   FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
258             self_shamir_share_table_ != nullptr)
259       << "Shamir Shares Tables haven't been initialized";
260 
261   // Verify the client sent all the right types of shares.
262   for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
263     int ith_neighbor = GetNeighbor(client_id, i);
264     switch (ClientMaskType(client_status(ith_neighbor))) {
265       case ClientMask::kPairwiseMask:
266         if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
267             NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
268           return ::absl::InvalidArgumentError(
269               "Client did not include the correct type of key share.");
270         }
271         break;
272       case ClientMask::kSelfMask:
273         if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
274             NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
275           return ::absl::InvalidArgumentError(
276               "Client did not include the correct type of key share.");
277         }
278         break;
279       case ClientMask::kNoMask:
280       default:
281         if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
282             NoiseOrPrfKeyShare::OneofSharesCase::ONEOF_SHARES_NOT_SET) {
283           return ::absl::InvalidArgumentError(
284               "Client included a key share for which none was expected.");
285         }
286     }
287   }
288   // Prepare the received key shares for reconstruction by inserting them into
289   // the tables.
290   for (int i = 0; i < number_of_neighbors(); ++i) {
291     // Find the index of client_id in the list of neighbors of the ith
292     // neighbor of client_id
293     int ith_neighbor = GetNeighbor(client_id, i);
294     int index = GetNeighborIndexOrDie(ith_neighbor, client_id);
295     if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() ==
296         NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
297       (*pairwise_shamir_share_table_)[ith_neighbor][index].data =
298           unmasking_response.noise_or_prf_key_shares(i).noise_sk_share();
299     } else if (unmasking_response.noise_or_prf_key_shares(i)
300                    .oneof_shares_case() ==
301                NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
302       (*self_shamir_share_table_)[ith_neighbor][index].data =
303           unmasking_response.noise_or_prf_key_shares(i).prf_sk_share();
304     }
305   }
306   return ::absl::OkStatus();
307 }
308 
309 // -----------------------------------------------------------------------------
310 // PRNG computation methods
311 // -----------------------------------------------------------------------------
312 
313 StatusOr<SecAggServerProtocolImpl::ShamirReconstructionResult>
HandleShamirReconstruction()314 SecAggServerProtocolImpl::HandleShamirReconstruction() {
315   FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
316             self_shamir_share_table_ != nullptr)
317       << "Shamir Shares Tables haven't been initialized";
318 
319   ShamirReconstructionResult result;
320   ShamirSecretSharing reconstructor;
321 
322   for (const auto& item : *pairwise_shamir_share_table_) {
323     FCP_ASSIGN_OR_RETURN(std::string reconstructed_key,
324                          reconstructor.Reconstruct(
325                              minimum_surviving_neighbors_for_reconstruction(),
326                              item.second, EcdhPrivateKey::kSize));
327     auto key_agreement = EcdhKeyAgreement::CreateFromPrivateKey(EcdhPrivateKey(
328         reinterpret_cast<const uint8_t*>(reconstructed_key.c_str())));
329     if (!key_agreement.ok()) {
330       // The server was unable to reconstruct the private key, probably
331       // because some client(s) sent invalid key shares. The only way out is
332       // to abort.
333       return ::absl::InvalidArgumentError(
334           "Unable to reconstruct aborted client's private key from shares");
335     }
336     result.aborted_client_key_agreements.try_emplace(
337         item.first, std::move(*(key_agreement.value())));
338   }
339 
340   for (const auto& item : *self_shamir_share_table_) {
341     FCP_ASSIGN_OR_RETURN(
342         AesKey reconstructed,
343         AesKey::CreateFromShares(
344             item.second, minimum_surviving_neighbors_for_reconstruction()));
345     result.self_keys.try_emplace(item.first, reconstructed);
346   }
347 
348   return std::move(result);
349 }
350 
351 StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
InitializePrng(const ShamirReconstructionResult & shamir_reconstruction_result) const352 SecAggServerProtocolImpl::InitializePrng(
353     const ShamirReconstructionResult& shamir_reconstruction_result) const {
354   PrngWorkItems work_items;
355 
356   for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
357     // Although clients who are DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED and
358     // kDeadAfterUnmaskingResponseReceived have they did so after sending
359     // their masked input. Therefore, it is possible to include their
360     // contribution to the aggregate sum. So we treat them here as if they had
361     // completed the protocol correctly.
362     auto status = client_status(i);
363     if (status != ClientStatus::UNMASKING_RESPONSE_RECEIVED &&
364         status != ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED &&
365         status != ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED) {
366       continue;
367     }
368 
369     // Since client i's value will be included in the sum, the server must
370     // remove its self mask.
371     auto it = shamir_reconstruction_result.self_keys.find(i);
372     FCP_CHECK(it != shamir_reconstruction_result.self_keys.end());
373     work_items.prng_keys_to_subtract.push_back(it->second);
374 
375     // For clients that aborted, client i's sum contains an un-canceled
376     // pairwise mask generated between the two clients. The server must remove
377     // this pairwise mask from the sum.
378     for (const auto& item :
379          shamir_reconstruction_result.aborted_client_key_agreements) {
380       if (!AreNeighbors(i, item.first)) {
381         continue;
382       }
383       auto shared_key =
384           item.second.ComputeSharedSecret(pairwise_public_keys(i));
385       if (!shared_key.ok()) {
386         // Should not happen; invalid public keys should already be detected.
387         // But if it does happen, abort.
388         return ::absl::InvalidArgumentError(
389             "Invalid public key from client detected");
390       }
391       if (IsOutgoingNeighbor(i, item.first)) {
392         work_items.prng_keys_to_add.push_back(shared_key.value());
393       } else {
394         work_items.prng_keys_to_subtract.push_back(shared_key.value());
395       }
396     }
397   }
398 
399   return std::move(work_items);
400 }
401 
402 }  // namespace secagg
403 }  // namespace fcp
404