1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/secagg/server/secagg_server_protocol_impl.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/server/tracing_schema.h"
#include "fcp/secagg/shared/compute_session_id.h"
#include "fcp/tracing/tracing_span.h"
29
30 namespace {
31
32 // Defines an experiments object with no experiments enabled
33 class EmptyExperiment : public fcp::secagg::ExperimentsInterface {
34 public:
IsEnabled(absl::string_view experiment_name)35 bool IsEnabled(absl::string_view experiment_name) override { return false; }
36 };
37
38 } // namespace
39
40 namespace fcp {
41 namespace secagg {
42
SecAggServerProtocolImpl(std::unique_ptr<SecretSharingGraph> graph,int minimum_number_of_clients_to_proceed,std::unique_ptr<SecAggServerMetricsListener> metrics,std::unique_ptr<AesPrngFactory> prng_factory,SendToClientsInterface * sender,std::unique_ptr<SecAggScheduler> scheduler,std::vector<ClientStatus> client_statuses,std::unique_ptr<ExperimentsInterface> experiments)43 SecAggServerProtocolImpl::SecAggServerProtocolImpl(
44 std::unique_ptr<SecretSharingGraph> graph,
45 int minimum_number_of_clients_to_proceed,
46 std::unique_ptr<SecAggServerMetricsListener> metrics,
47 std::unique_ptr<AesPrngFactory> prng_factory,
48 SendToClientsInterface* sender, std::unique_ptr<SecAggScheduler> scheduler,
49 std::vector<ClientStatus> client_statuses,
50 std::unique_ptr<ExperimentsInterface> experiments)
51 : secret_sharing_graph_(std::move(graph)),
52 minimum_number_of_clients_to_proceed_(
53 minimum_number_of_clients_to_proceed),
54 metrics_(std::move(metrics)),
55 prng_factory_(std::move(prng_factory)),
56 sender_(sender),
57 scheduler_(std::move(scheduler)),
58 total_number_of_clients_(client_statuses.size()),
59 client_statuses_(std::move(client_statuses)),
60 experiments_(experiments ? std::move(experiments)
61 : std::unique_ptr<ExperimentsInterface>(
62 new EmptyExperiment())),
63 pairwise_public_keys_(total_number_of_clients()),
64 pairs_of_public_keys_(total_number_of_clients()),
65 encrypted_shares_(total_number_of_clients(),
66 std::vector<std::string>(number_of_neighbors())) {}
67
SetResult(std::unique_ptr<SecAggVectorMap> result)68 void SecAggServerProtocolImpl::SetResult(
69 std::unique_ptr<SecAggVectorMap> result) {
70 FCP_CHECK(!result_) << "Result can't be set twice";
71 result_ = std::move(result);
72 }
73
TakeResult()74 std::unique_ptr<SecAggVectorMap> SecAggServerProtocolImpl::TakeResult() {
75 return std::move(result_);
76 }
77
78 // -----------------------------------------------------------------------------
79 // Round 0 methods
80 // -----------------------------------------------------------------------------
81
HandleAdvertiseKeys(uint32_t client_id,const AdvertiseKeys & advertise_keys)82 Status SecAggServerProtocolImpl::HandleAdvertiseKeys(
83 uint32_t client_id, const AdvertiseKeys& advertise_keys) {
84 const auto& pair_of_public_keys = advertise_keys.pair_of_public_keys();
85 if ((pair_of_public_keys.enc_pk().size() != EcdhPublicKey::kSize &&
86 (pair_of_public_keys.enc_pk().size() <
87 EcdhPublicKey::kUncompressedSize ||
88 pair_of_public_keys.noise_pk().size() <
89 EcdhPublicKey::kUncompressedSize)) ||
90 pair_of_public_keys.enc_pk().size() !=
91 pair_of_public_keys.noise_pk().size()) {
92 return ::absl::InvalidArgumentError(
93 "A public key sent by the client was not the correct size.");
94 }
95
96 if (pair_of_public_keys.noise_pk().size() == EcdhPublicKey::kSize) {
97 pairwise_public_keys_[client_id] =
98 EcdhPublicKey(reinterpret_cast<const uint8_t*>(
99 pair_of_public_keys.noise_pk().c_str()));
100 } else {
101 // Strip off the header, if any, and use the uncompressed ECDH key.
102 size_t key_size_with_header = pair_of_public_keys.noise_pk().size();
103 pairwise_public_keys_[client_id] = EcdhPublicKey(
104 reinterpret_cast<const uint8_t*>(
105 pair_of_public_keys.noise_pk()
106 .substr(key_size_with_header - EcdhPublicKey::kUncompressedSize)
107 .c_str()),
108 EcdhPublicKey::kUncompressed);
109 }
110
111 pairs_of_public_keys_[client_id] = pair_of_public_keys;
112 return ::absl::OkStatus();
113 }
114
ErasePublicKeysForClient(uint32_t client_id)115 void SecAggServerProtocolImpl::ErasePublicKeysForClient(uint32_t client_id) {
116 pairwise_public_keys_[client_id] = EcdhPublicKey();
117 pairs_of_public_keys_[client_id] = PairOfPublicKeys();
118 }
119
ComputeSessionId()120 void SecAggServerProtocolImpl::ComputeSessionId() {
121 // This message contains all keys, and is only built for the purpose
122 // of deriving the session key from it
123 ShareKeysRequest share_keys_request;
124 for (int i = 0; i < total_number_of_clients(); ++i) {
125 *(share_keys_request.add_pairs_of_public_keys()) = pairs_of_public_keys_[i];
126 }
127 set_session_id(std::make_unique<SessionId>(
128 fcp::secagg::ComputeSessionId(share_keys_request)));
129 }
130
PrepareShareKeysRequestForClient(uint32_t client_id,ShareKeysRequest * request) const131 void SecAggServerProtocolImpl::PrepareShareKeysRequestForClient(
132 uint32_t client_id, ShareKeysRequest* request) const {
133 request->clear_pairs_of_public_keys();
134 for (int j = 0; j < secret_sharing_graph()->GetDegree(); ++j) {
135 *request->add_pairs_of_public_keys() =
136 pairs_of_public_keys_[secret_sharing_graph()->GetNeighbor(client_id,
137 j)];
138 }
139 }
140
ClearPairsOfPublicKeys()141 void SecAggServerProtocolImpl::ClearPairsOfPublicKeys() {
142 pairs_of_public_keys_.clear();
143 }
144
145 // -----------------------------------------------------------------------------
146 // Round 1 methods
147 // -----------------------------------------------------------------------------
148
HandleShareKeysResponse(uint32_t client_id,const ShareKeysResponse & share_keys_response)149 Status SecAggServerProtocolImpl::HandleShareKeysResponse(
150 uint32_t client_id, const ShareKeysResponse& share_keys_response) {
151 // Verify that the message has the expected fields set before accepting it.
152 if (share_keys_response.encrypted_key_shares().size() !=
153 number_of_neighbors()) {
154 return ::absl::InvalidArgumentError(
155 "The ShareKeysResponse does not contain the expected number of "
156 "encrypted pairs of key shares.");
157 }
158
159 for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
160 bool i_is_empty = share_keys_response.encrypted_key_shares(i).empty();
161 int neighbor_id = GetNeighbor(client_id, i);
162 bool i_should_be_empty = (neighbor_id == client_id) ||
163 (client_status(neighbor_id) ==
164 ClientStatus::DEAD_BEFORE_SENDING_ANYTHING);
165 if (i_is_empty && !i_should_be_empty) {
166 return ::absl::InvalidArgumentError(
167 "Client omitted a key share that was expected.");
168 }
169 if (i_should_be_empty && !i_is_empty) {
170 return ::absl::InvalidArgumentError(
171 "Client sent a key share that was not expected.");
172 }
173 }
174
175 // Client sent a valid message.
176 for (int i = 0; i < number_of_neighbors(); ++i) {
177 int neighbor_id = GetNeighbor(client_id, i);
178 // neighbor_id and client_id are neighbors, and thus index_in_neighbors is
179 // in [0, number_neighbors()-1]
180 int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
181 encrypted_shares_[neighbor_id][index_in_neighbor] =
182 share_keys_response.encrypted_key_shares(i);
183 }
184
185 return ::absl::OkStatus();
186 }
187
EraseShareKeysForClient(uint32_t client_id)188 void SecAggServerProtocolImpl::EraseShareKeysForClient(uint32_t client_id) {
189 for (int i = 0; i < number_of_neighbors(); ++i) {
190 int neighbor_id = GetNeighbor(client_id, i);
191 int index_in_neighbor = GetNeighborIndexOrDie(neighbor_id, client_id);
192 encrypted_shares_[neighbor_id][index_in_neighbor].clear();
193 }
194 }
195
PrepareMaskedInputCollectionRequestForClient(uint32_t client_id,MaskedInputCollectionRequest * request) const196 void SecAggServerProtocolImpl::PrepareMaskedInputCollectionRequestForClient(
197 uint32_t client_id, MaskedInputCollectionRequest* request) const {
198 request->clear_encrypted_key_shares();
199 for (int j = 0; j < number_of_neighbors(); ++j) {
200 request->add_encrypted_key_shares(encrypted_shares_[client_id][j]);
201 }
202 }
203
ClearShareKeys()204 void SecAggServerProtocolImpl::ClearShareKeys() { encrypted_shares_.clear(); }
205
206 // -----------------------------------------------------------------------------
207 // Round 3 methods
208 // -----------------------------------------------------------------------------
209
210 // This enum and the following function relates the client status to whether
211 // or not its pairwise mask, its self mask, or neither will appear in the
212 // summed masked input.
213 enum class ClientMask { kPairwiseMask, kSelfMask, kNoMask };
214
215 // Returns the type of mask the server expects to receive a share for, for a
216 // give client status.
ClientMaskType(const ClientStatus & client_status)217 inline ClientMask ClientMaskType(const ClientStatus& client_status) {
218 switch (client_status) {
219 case ClientStatus::SHARE_KEYS_RECEIVED:
220 case ClientStatus::DEAD_AFTER_SHARE_KEYS_RECEIVED:
221 return ClientMask::kPairwiseMask;
222 break;
223 case ClientStatus::MASKED_INPUT_RESPONSE_RECEIVED:
224 case ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED:
225 case ClientStatus::UNMASKING_RESPONSE_RECEIVED:
226 case ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED:
227 return ClientMask::kSelfMask;
228 break;
229 case ClientStatus::READY_TO_START:
230 case ClientStatus::DEAD_BEFORE_SENDING_ANYTHING:
231 case ClientStatus::ADVERTISE_KEYS_RECEIVED:
232 case ClientStatus::DEAD_AFTER_ADVERTISE_KEYS_RECEIVED:
233 default:
234 return ClientMask::kNoMask;
235 }
236 }
237
SetUpShamirSharesTables()238 void SecAggServerProtocolImpl::SetUpShamirSharesTables() {
239 pairwise_shamir_share_table_ = std::make_unique<
240 absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
241 self_shamir_share_table_ = std::make_unique<
242 absl::flat_hash_map<uint32_t, std::vector<ShamirShare>>>();
243
244 // Prepare the share tables with rows for clients we expect to have shares for
245 for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
246 auto mask_type = ClientMaskType(client_status(i));
247 if (mask_type == ClientMask::kPairwiseMask) {
248 pairwise_shamir_share_table_->emplace(i, number_of_neighbors());
249 } else if (mask_type == ClientMask::kSelfMask) {
250 self_shamir_share_table_->emplace(i, number_of_neighbors());
251 }
252 }
253 }
254
HandleUnmaskingResponse(uint32_t client_id,const UnmaskingResponse & unmasking_response)255 Status SecAggServerProtocolImpl::HandleUnmaskingResponse(
256 uint32_t client_id, const UnmaskingResponse& unmasking_response) {
257 FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
258 self_shamir_share_table_ != nullptr)
259 << "Shamir Shares Tables haven't been initialized";
260
261 // Verify the client sent all the right types of shares.
262 for (uint32_t i = 0; i < number_of_neighbors(); ++i) {
263 int ith_neighbor = GetNeighbor(client_id, i);
264 switch (ClientMaskType(client_status(ith_neighbor))) {
265 case ClientMask::kPairwiseMask:
266 if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
267 NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
268 return ::absl::InvalidArgumentError(
269 "Client did not include the correct type of key share.");
270 }
271 break;
272 case ClientMask::kSelfMask:
273 if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
274 NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
275 return ::absl::InvalidArgumentError(
276 "Client did not include the correct type of key share.");
277 }
278 break;
279 case ClientMask::kNoMask:
280 default:
281 if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() !=
282 NoiseOrPrfKeyShare::OneofSharesCase::ONEOF_SHARES_NOT_SET) {
283 return ::absl::InvalidArgumentError(
284 "Client included a key share for which none was expected.");
285 }
286 }
287 }
288 // Prepare the received key shares for reconstruction by inserting them into
289 // the tables.
290 for (int i = 0; i < number_of_neighbors(); ++i) {
291 // Find the index of client_id in the list of neighbors of the ith
292 // neighbor of client_id
293 int ith_neighbor = GetNeighbor(client_id, i);
294 int index = GetNeighborIndexOrDie(ith_neighbor, client_id);
295 if (unmasking_response.noise_or_prf_key_shares(i).oneof_shares_case() ==
296 NoiseOrPrfKeyShare::OneofSharesCase::kNoiseSkShare) {
297 (*pairwise_shamir_share_table_)[ith_neighbor][index].data =
298 unmasking_response.noise_or_prf_key_shares(i).noise_sk_share();
299 } else if (unmasking_response.noise_or_prf_key_shares(i)
300 .oneof_shares_case() ==
301 NoiseOrPrfKeyShare::OneofSharesCase::kPrfSkShare) {
302 (*self_shamir_share_table_)[ith_neighbor][index].data =
303 unmasking_response.noise_or_prf_key_shares(i).prf_sk_share();
304 }
305 }
306 return ::absl::OkStatus();
307 }
308
309 // -----------------------------------------------------------------------------
310 // PRNG computation methods
311 // -----------------------------------------------------------------------------
312
313 StatusOr<SecAggServerProtocolImpl::ShamirReconstructionResult>
HandleShamirReconstruction()314 SecAggServerProtocolImpl::HandleShamirReconstruction() {
315 FCP_CHECK(pairwise_shamir_share_table_ != nullptr &&
316 self_shamir_share_table_ != nullptr)
317 << "Shamir Shares Tables haven't been initialized";
318
319 ShamirReconstructionResult result;
320 ShamirSecretSharing reconstructor;
321
322 for (const auto& item : *pairwise_shamir_share_table_) {
323 FCP_ASSIGN_OR_RETURN(std::string reconstructed_key,
324 reconstructor.Reconstruct(
325 minimum_surviving_neighbors_for_reconstruction(),
326 item.second, EcdhPrivateKey::kSize));
327 auto key_agreement = EcdhKeyAgreement::CreateFromPrivateKey(EcdhPrivateKey(
328 reinterpret_cast<const uint8_t*>(reconstructed_key.c_str())));
329 if (!key_agreement.ok()) {
330 // The server was unable to reconstruct the private key, probably
331 // because some client(s) sent invalid key shares. The only way out is
332 // to abort.
333 return ::absl::InvalidArgumentError(
334 "Unable to reconstruct aborted client's private key from shares");
335 }
336 result.aborted_client_key_agreements.try_emplace(
337 item.first, std::move(*(key_agreement.value())));
338 }
339
340 for (const auto& item : *self_shamir_share_table_) {
341 FCP_ASSIGN_OR_RETURN(
342 AesKey reconstructed,
343 AesKey::CreateFromShares(
344 item.second, minimum_surviving_neighbors_for_reconstruction()));
345 result.self_keys.try_emplace(item.first, reconstructed);
346 }
347
348 return std::move(result);
349 }
350
351 StatusOr<SecAggServerProtocolImpl::PrngWorkItems>
InitializePrng(const ShamirReconstructionResult & shamir_reconstruction_result) const352 SecAggServerProtocolImpl::InitializePrng(
353 const ShamirReconstructionResult& shamir_reconstruction_result) const {
354 PrngWorkItems work_items;
355
356 for (uint32_t i = 0; i < total_number_of_clients(); ++i) {
357 // Although clients who are DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED and
358 // kDeadAfterUnmaskingResponseReceived have they did so after sending
359 // their masked input. Therefore, it is possible to include their
360 // contribution to the aggregate sum. So we treat them here as if they had
361 // completed the protocol correctly.
362 auto status = client_status(i);
363 if (status != ClientStatus::UNMASKING_RESPONSE_RECEIVED &&
364 status != ClientStatus::DEAD_AFTER_UNMASKING_RESPONSE_RECEIVED &&
365 status != ClientStatus::DEAD_AFTER_MASKED_INPUT_RESPONSE_RECEIVED) {
366 continue;
367 }
368
369 // Since client i's value will be included in the sum, the server must
370 // remove its self mask.
371 auto it = shamir_reconstruction_result.self_keys.find(i);
372 FCP_CHECK(it != shamir_reconstruction_result.self_keys.end());
373 work_items.prng_keys_to_subtract.push_back(it->second);
374
375 // For clients that aborted, client i's sum contains an un-canceled
376 // pairwise mask generated between the two clients. The server must remove
377 // this pairwise mask from the sum.
378 for (const auto& item :
379 shamir_reconstruction_result.aborted_client_key_agreements) {
380 if (!AreNeighbors(i, item.first)) {
381 continue;
382 }
383 auto shared_key =
384 item.second.ComputeSharedSecret(pairwise_public_keys(i));
385 if (!shared_key.ok()) {
386 // Should not happen; invalid public keys should already be detected.
387 // But if it does happen, abort.
388 return ::absl::InvalidArgumentError(
389 "Invalid public key from client detected");
390 }
391 if (IsOutgoingNeighbor(i, item.first)) {
392 work_items.prng_keys_to_add.push_back(shared_key.value());
393 } else {
394 work_items.prng_keys_to_subtract.push_back(shared_key.value());
395 }
396 }
397 }
398
399 return std::move(work_items);
400 }
401
402 } // namespace secagg
403 } // namespace fcp
404