xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
18 
19 #include <algorithm>
20 #include <cstddef>
21 #include <functional>
22 #include <iterator>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/container/node_hash_map.h"
29 #include "absl/status/status.h"
30 #include "fcp/base/monitoring.h"
31 #include "fcp/secagg/server/experiments_names.h"
32 #include "fcp/secagg/server/secagg_scheduler.h"
33 #include "fcp/secagg/shared/map_of_masks.h"
34 #include "fcp/secagg/shared/math.h"
35 #include "fcp/secagg/shared/secagg_vector.h"
36 
37 namespace {
38 
AddReduce(std::vector<std::unique_ptr<fcp::secagg::SecAggVectorMap>> vector_of_maps)39 std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> AddReduce(
40     std::vector<std::unique_ptr<fcp::secagg::SecAggVectorMap>> vector_of_maps) {
41   FCP_CHECK(!vector_of_maps.empty());
42   // Initialize result
43   auto result = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>(
44       *vector_of_maps[0]);
45   // Reduce vector of maps
46   for (int i = 1; i < vector_of_maps.size(); ++i) {
47     result->Add(*vector_of_maps[i]);
48   }
49   return result;
50 }
51 
52 // Initializes a SecAggUnpackedVectorMap object according to a provided input
53 // vector specification
InitializeVectorMap(const std::vector<fcp::secagg::InputVectorSpecification> & input_vector_specs)54 std::unique_ptr<fcp::secagg::SecAggUnpackedVectorMap> InitializeVectorMap(
55     const std::vector<fcp::secagg::InputVectorSpecification>&
56         input_vector_specs) {
57   auto vector_map = std::make_unique<fcp::secagg::SecAggUnpackedVectorMap>();
58   for (const fcp::secagg::InputVectorSpecification& vector_spec :
59        input_vector_specs) {
60     vector_map->emplace(vector_spec.name(),
61                         fcp::secagg::SecAggUnpackedVector(
62                             vector_spec.length(), vector_spec.modulus()));
63   }
64   return vector_map;
65 }
66 
67 }  // namespace
68 
69 namespace fcp {
70 namespace secagg {
71 
// Maximum number of PRNG keys expanded by one mask-generation job.
//
// Keys are split into batches of this size so that mask generation can be
// spread across several scheduled tasks. The last batch may be smaller.
static constexpr int kPrngBatchSize{32};
74 
75 std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
SetupMaskedInputCollection()76 AesSecAggServerProtocolImpl::SetupMaskedInputCollection() {
77   if (!experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
78     // Prepare the sum of masked input vectors with all zeroes.
79     masked_input_ = InitializeVectorMap(input_vector_specs());
80   } else {
81     auto initial_value = InitializeVectorMap(input_vector_specs());
82     masked_input_accumulator_ =
83         scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
84             std::move(initial_value), SecAggUnpackedVectorMap::AddMaps);
85   }
86   return masked_input_accumulator_;
87 }
88 
89 std::vector<std::unique_ptr<SecAggVectorMap>>
TakeMaskedInputQueue()90 AesSecAggServerProtocolImpl::TakeMaskedInputQueue() {
91   absl::MutexLock lock(&mutex_);
92   return std::move(masked_input_queue_);
93 }
94 
HandleMaskedInputCollectionResponse(std::unique_ptr<MaskedInputCollectionResponse> masked_input_response)95 Status AesSecAggServerProtocolImpl::HandleMaskedInputCollectionResponse(
96     std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) {
97   FCP_CHECK(masked_input_response);
98   // Make sure the received vectors match the specification.
99   if (masked_input_response->vectors().size() != input_vector_specs().size()) {
100     return ::absl::InvalidArgumentError(
101         "Masked input does not match input vector specification - "
102         "wrong number of vectors.");
103   }
104   auto& input_vectors = *masked_input_response->mutable_vectors();
105   auto checked_masked_vectors = std::make_unique<SecAggVectorMap>();
106   for (const InputVectorSpecification& vector_spec : input_vector_specs()) {
107     auto masked_vector = input_vectors.find(vector_spec.name());
108     if (masked_vector == input_vectors.end()) {
109       return ::absl::InvalidArgumentError(
110           "Masked input does not match input vector specification - wrong "
111           "vector names.");
112     }
113     // TODO(team): This does not appear to be properly covered by unit
114     // tests.
115     int bit_width = SecAggVector::GetBitWidth(vector_spec.modulus());
116     if (masked_vector->second.encoded_vector().size() !=
117         DivideRoundUp(vector_spec.length() * bit_width, 8)) {
118       return ::absl::InvalidArgumentError(
119           "Masked input does not match input vector specification - vector is "
120           "wrong size.");
121     }
122     checked_masked_vectors->emplace(
123         vector_spec.name(),
124         SecAggVector(std::move(*masked_vector->second.mutable_encoded_vector()),
125                      vector_spec.modulus(), vector_spec.length()));
126   }
127 
128   if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
129     // If async processing is enabled we queue the client message. Moreover, if
130     // the queue we found was empty this means that it has been taken by an
131     // asynchronous aggregation task. In that case, we schedule an aggregation
132     // task to process the queue that we just initiated, which will happen
133     // eventually.
134     size_t is_queue_empty;
135     {
136       absl::MutexLock lock(&mutex_);
137       is_queue_empty = masked_input_queue_.empty();
138       masked_input_queue_.emplace_back(std::move(checked_masked_vectors));
139     }
140     if (is_queue_empty) {
141       // TODO(team): Abort should handle the situation where `this` has
142       // been destructed while the schedule task is still not running, and
143       // message_queue_ can't be moved.
144       Trace<Round2AsyncWorkScheduled>();
145       masked_input_accumulator_->Schedule([&] {
146         auto queue = TakeMaskedInputQueue();
147         Trace<Round2MessageQueueTaken>(queue.size());
148         return AddReduce(std::move(queue));
149       });
150     }
151   } else {
152     // Sequential processing
153     FCP_CHECK(masked_input_);
154     masked_input_->Add(*checked_masked_vectors);
155   }
156 
157   return ::absl::OkStatus();
158 }
159 
FinalizeMaskedInputCollection()160 void AesSecAggServerProtocolImpl::FinalizeMaskedInputCollection() {
161   if (experiments()->IsEnabled(kSecAggAsyncRound2Experiment)) {
162     FCP_CHECK(masked_input_accumulator_->IsIdle());
163     masked_input_ = masked_input_accumulator_->GetResultAndCancel();
164   }
165 }
166 
StartPrng(const PrngWorkItems & work_items,std::function<void (Status)> done_callback)167 CancellationToken AesSecAggServerProtocolImpl::StartPrng(
168     const PrngWorkItems& work_items,
169     std::function<void(Status)> done_callback) {
170   FCP_CHECK(done_callback);
171   FCP_CHECK(masked_input_);
172   auto generators =
173       std::vector<std::function<std::unique_ptr<SecAggUnpackedVectorMap>()>>();
174 
175   // Break the keys to add or subtract into vectors of size kPrngBatchSize (or
176   // less for the last one) and schedule them as tasks.
177   for (auto it = work_items.prng_keys_to_add.begin();
178        it < work_items.prng_keys_to_add.end(); it += kPrngBatchSize) {
179     std::vector<AesKey> batch_prng_keys_to_add;
180     std::copy(it,
181               std::min(it + kPrngBatchSize, work_items.prng_keys_to_add.end()),
182               std::back_inserter(batch_prng_keys_to_add));
183     generators.emplace_back([=]() {
184       return UnpackedMapOfMasks(batch_prng_keys_to_add, std::vector<AesKey>(),
185                                 input_vector_specs(), session_id(),
186                                 *prng_factory());
187     });
188   }
189 
190   for (auto it = work_items.prng_keys_to_subtract.begin();
191        it < work_items.prng_keys_to_subtract.end(); it += kPrngBatchSize) {
192     std::vector<AesKey> batch_prng_keys_to_subtract;
193     std::copy(
194         it,
195         std::min(it + kPrngBatchSize, work_items.prng_keys_to_subtract.end()),
196         std::back_inserter(batch_prng_keys_to_subtract));
197     generators.emplace_back([=]() {
198       return UnpackedMapOfMasks(
199           std::vector<AesKey>(), batch_prng_keys_to_subtract,
200           input_vector_specs(), session_id(), *prng_factory());
201     });
202   }
203 
204   auto accumulator = scheduler()->CreateAccumulator<SecAggUnpackedVectorMap>(
205       std::move(masked_input_), SecAggUnpackedVectorMap::AddMaps);
206   for (const auto& generator : generators) {
207     accumulator->Schedule(generator);
208   }
209   accumulator->SetAsyncObserver([=, accumulator = accumulator.get()]() {
210     auto unpacked_map = accumulator->GetResultAndCancel();
211     auto packed_map = std::make_unique<SecAggVectorMap>();
212     for (auto& entry : *unpacked_map) {
213       uint64_t modulus = entry.second.modulus();
214       packed_map->emplace(entry.first,
215                           SecAggVector(std::move(entry.second), modulus));
216     }
217     SetResult(std::move(packed_map));
218     done_callback(absl::OkStatus());
219   });
220   return accumulator;
221 }
222 }  // namespace secagg
223 }  // namespace fcp
224