xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/map_of_masks.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/shared/map_of_masks.h"
18 
19 #include <algorithm>
20 #include <atomic>
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/numeric/bits.h"
28 #include "absl/strings/str_cat.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/shared/aes_key.h"
31 #include "fcp/secagg/shared/compute_session_id.h"
32 #include "fcp/secagg/shared/input_vector_specification.h"
33 #include "fcp/secagg/shared/math.h"
34 #include "fcp/secagg/shared/prng.h"
35 #include "fcp/secagg/shared/secagg_vector.h"
36 #include "openssl/evp.h"
37 
38 namespace fcp {
39 namespace secagg {
40 
// Extra byte mixed into every PRNG seed digest (see DigestKey). It is no
// longer needed, but legacy clients still include it, so dropping it would
// make a large number of clients incompatible without any benefit.
uint8_t kPrngSeedConstant{0x02};

// Largest number of bits a single rejection-sampling draw may use. 64 is
// deliberately excluded: 2^64, the matching sampling modulus, does not fit in
// a uint64_t.
constexpr int kMaxSampleBits{63};

// How many PRNG bits beyond the minimum required width a single draw may use
// when searching for the most byte-efficient sample width.
constexpr int kMaxSampleBitsExpansion{16};
54 
DigestKey(EVP_MD_CTX * mdctx,const std::string & prng_input,int bit_width,const AesKey & prng_key)55 static AesKey DigestKey(EVP_MD_CTX* mdctx, const std::string& prng_input,
56                         int bit_width, const AesKey& prng_key) {
57   int input_size = prng_input.size();
58   std::string input_size_data = IntToByteString(input_size);
59   std::string bit_width_data = IntToByteString(bit_width);
60   FCP_CHECK(EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr));
61   FCP_CHECK(EVP_DigestUpdate(mdctx, bit_width_data.c_str(), sizeof(int)));
62   FCP_CHECK(EVP_DigestUpdate(mdctx, prng_key.data(), prng_key.size()));
63   FCP_CHECK(EVP_DigestUpdate(mdctx, &kPrngSeedConstant, 1));
64   FCP_CHECK(EVP_DigestUpdate(mdctx, input_size_data.c_str(), sizeof(int)));
65   FCP_CHECK(EVP_DigestUpdate(mdctx, prng_input.c_str(), input_size));
66 
67   uint8_t digest[AesKey::kSize];
68   uint32_t digest_length = 0;
69   FCP_CHECK(EVP_DigestFinal_ex(mdctx, digest, &digest_length));
70   FCP_CHECK(digest_length == AesKey::kSize);
71   return AesKey(digest);
72 }
73 
74 // Determines whether sample_bits_1 or sample_bits_2 will be more efficient
75 // for sampling uniformly from [0, modulus).
76 //
choose_better_sample_bits(uint64_t modulus,int sample_bits_1,int sample_bits_2)77 int choose_better_sample_bits(uint64_t modulus, int sample_bits_1,
78                               int sample_bits_2) {
79   FCP_CHECK(sample_bits_1 <= sample_bits_2);
80   FCP_CHECK(sample_bits_2 <= kMaxSampleBits);
81   FCP_CHECK(sample_bits_2 - sample_bits_1 <= kMaxSampleBitsExpansion);
82 
83   uint64_t sample_modulus_1 = 1ULL << sample_bits_1;
84   FCP_CHECK(modulus <= sample_modulus_1);
85 
86   if (sample_bits_1 == sample_bits_2) {
87     return sample_bits_1;
88   }
89 
90   uint64_t sample_modulus_2 = 1ULL << sample_bits_2;
91   uint64_t sample_modulus_2_over_1 = 1ULL << (sample_bits_2 - sample_bits_1);
92   uint32_t cost_per_sample_1 = DivideRoundUp(sample_bits_1, 8);
93   uint32_t cost_per_sample_2 = DivideRoundUp(sample_bits_2, 8);
94   uint64_t modulus_reps_1 = sample_modulus_1 / modulus;
95   uint64_t modulus_reps_2 = sample_modulus_2 / modulus;
96   uint64_t cost_product_1 = cost_per_sample_1 * modulus_reps_1;
97   uint64_t cost_product_2 =
98       cost_per_sample_2 * modulus_reps_2 * sample_modulus_2_over_1;
99   return cost_product_1 > cost_product_2 ? sample_bits_2 : sample_bits_1;
100 }
101 
102 // Computes the sample_bits that minimizes the expected number of bytes of
103 // randomness that will be consumed when drawing a uniform sample from
104 // [0, modulus) using our rejection sampling algorithm.
105 //
compute_best_sample_bits(uint64_t modulus)106 int compute_best_sample_bits(uint64_t modulus) {
107   int min_sample_bits = static_cast<int>(absl::bit_width(modulus - 1ULL));
108   int max_sample_bits = std::min(kMaxSampleBitsExpansion,
109                                  min_sample_bits + kMaxSampleBitsExpansion);
110   int best_sample_bits = min_sample_bits;
111   for (int sample_bits = min_sample_bits + 1; sample_bits <= max_sample_bits;
112        sample_bits++) {
113     best_sample_bits =
114         choose_better_sample_bits(modulus, best_sample_bits, sample_bits);
115   }
116   return best_sample_bits;
117 }
118 
119 // PrngBuffer implements the logic for generating pseudo-random masks while
120 // fetching and caching buffers of psedo-random uint8_t numbers.
121 // Two important factors of this implementation compared to using SecurePrng
122 // directly are:
123 // 1) The implementation is fully inlineable allowing the the compiler to
124 //    greatly optimize the resulting code.
125 // 2) Checking whether a new buffer of pseudo-random bytes needs to be filled is
126 //    done only once per mask as opposed to doing that for every byte, which
127 //    optimizes the most nested loop.
128 class PrngBuffer {
129  public:
PrngBuffer(std::unique_ptr<SecurePrng> prng,uint8_t msb_mask,size_t bytes_per_output)130   PrngBuffer(std::unique_ptr<SecurePrng> prng, uint8_t msb_mask,
131              size_t bytes_per_output)
132       : prng_(static_cast<SecureBatchPrng*>(prng.release())),
133         msb_mask_(msb_mask),
134         bytes_per_output_(bytes_per_output),
135         buffer_(prng_->GetMaxBufferSize()),
136         buffer_end_(buffer_.data() + buffer_.size()) {
137     FCP_CHECK((prng_->GetMaxBufferSize() % bytes_per_output) == 0)
138         << "PRNG buffer size must be a multiple bytes_per_output.";
139     FillBuffer();
140   }
141 
NextMask()142   inline uint64_t NextMask() {
143     if (buffer_ptr_ == buffer_end_) {
144       FillBuffer();
145     }
146 
147     auto output = static_cast<uint64_t>((*buffer_ptr_++) & msb_mask_);
148     for (size_t i = 1; i < bytes_per_output_; ++i) {
149       output <<= 8UL;
150       output |= static_cast<uint64_t>(*buffer_ptr_++);
151     }
152     return output;
153   }
154 
155  private:
buffer_size()156   inline int buffer_size() { return static_cast<int>(buffer_.size()); }
157 
FillBuffer()158   inline void FillBuffer() {
159     buffer_ptr_ = buffer_.data();
160     FCP_CHECK(prng_->RandBuffer(buffer_.data(), buffer_size()) ==
161               buffer_size());
162   }
163 
164   std::unique_ptr<SecureBatchPrng> prng_;
165   const uint8_t msb_mask_;
166   const size_t bytes_per_output_;
167   std::vector<uint8_t> buffer_;
168   const uint8_t* buffer_ptr_ = nullptr;
169   const uint8_t* const buffer_end_;
170 };
171 
172 struct AddModAdapter {
AddModImplfcp::secagg::AddModAdapter173   inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
174     return AddMod(a, b, z);
175   }
SubtractModImplfcp::secagg::AddModAdapter176   inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
177     return SubtractMod(a, b, z);
178   }
179 };
180 
181 struct AddModOptAdapter {
AddModImplfcp::secagg::AddModOptAdapter182   inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
183     return AddModOpt(a, b, z);
184   }
SubtractModImplfcp::secagg::AddModOptAdapter185   inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
186     return SubtractModOpt(a, b, z);
187   }
188 };
189 
190 // Templated implementation of MapOfMasks that allows substituting
191 // AddMod and SubtractMod implementations.
192 template <typename TAdapter, typename TVector, typename TVectorMap>
MapOfMasksImpl(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)193 inline std::unique_ptr<TVectorMap> MapOfMasksImpl(
194     const std::vector<AesKey>& prng_keys_to_add,
195     const std::vector<AesKey>& prng_keys_to_subtract,
196     const std::vector<InputVectorSpecification>& input_vector_specs,
197     const SessionId& session_id, const AesPrngFactory& prng_factory,
198     AsyncAbort* async_abort) {
199   FCP_CHECK(prng_factory.SupportsBatchMode());
200 
201   auto map_of_masks = std::make_unique<TVectorMap>();
202   std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
203                                                            EVP_MD_CTX_destroy);
204   FCP_CHECK(mdctx.get());
205   for (const InputVectorSpecification& vector_spec : input_vector_specs) {
206     if (async_abort && async_abort->Signalled()) return nullptr;
207     int bit_width =
208         static_cast<int>(absl::bit_width(vector_spec.modulus() - 1ULL));
209     std::string prng_input =
210         absl::StrCat(session_id.data, IntToByteString(bit_width),
211                      IntToByteString(vector_spec.length()), vector_spec.name());
212     std::vector<uint64_t> mask_vector_buffer(vector_spec.length(), 0);
213 
214     bool modulus_is_power_of_two = (1ULL << bit_width == vector_spec.modulus());
215     if (modulus_is_power_of_two) {
216       // Because the modulus is a power of two, we can sample uniformly
217       // simply by drawing the correct number of random bits.
218       int bytes_per_output = DivideRoundUp(bit_width, 8);
219       // msb = "most significant byte"
220       size_t bits_in_msb = bit_width - ((bytes_per_output - 1) * 8);
221       uint8_t msb_mask = (1UL << bits_in_msb) - 1;
222 
223       for (const auto& prng_key : prng_keys_to_add) {
224         if (async_abort && async_abort->Signalled()) return nullptr;
225         AesKey digest_key =
226             DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
227         PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
228                         bytes_per_output);
229         for (auto& v : mask_vector_buffer) {
230           v = TAdapter::AddModImpl(v, prng.NextMask(), vector_spec.modulus());
231         }
232       }
233 
234       for (const auto& prng_key : prng_keys_to_subtract) {
235         if (async_abort && async_abort->Signalled()) return nullptr;
236         AesKey digest_key =
237             DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
238         PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
239                         bytes_per_output);
240         for (auto& v : mask_vector_buffer) {
241           v = TAdapter::SubtractModImpl(v, prng.NextMask(),
242                                         vector_spec.modulus());
243         }
244       }
245     } else {
246       // Rejection Sampling algorithm for arbitrary moduli.
247       // Follows Algorithm 3 from:
248       // "Fast Random Integer Generation in an Interval," Daniel Lemire, 2018.
249       // https://arxiv.org/pdf/1805.10941.pdf.
250       //
251       // The inner loops are structured to avoid conditional branches
252       // and the associated branch misprediction errors they would entail.
253       //
254       // We choose sample_bits to minimize the expected number of bytes
255       // drawn from the PRNG.
256 
257       int sample_bits = compute_best_sample_bits(vector_spec.modulus());
258       int bytes_per_output = DivideRoundUp(sample_bits, 8);
259       // msb = "most significant byte"
260       size_t bits_in_msb = sample_bits - ((bytes_per_output - 1) * 8);
261       uint8_t msb_mask = (1UL << bits_in_msb) - 1;
262 
263       uint64_t sample_modulus = 1ULL << sample_bits;
264       uint64_t rejection_threshold =
265           (sample_modulus - vector_spec.modulus()) % vector_spec.modulus();
266 
267       for (const auto& prng_key : prng_keys_to_add) {
268         if (async_abort && async_abort->Signalled()) return nullptr;
269         AesKey digest_key =
270             DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
271         PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
272                         bytes_per_output);
273         int i = 0;
274         while (i < vector_spec.length()) {
275           auto& v = mask_vector_buffer[i];
276           auto mask = prng.NextMask();
277           auto reject = mask < rejection_threshold;
278           auto inc = reject ? 0 : 1;
279           mask = reject ? 0 : mask;
280           v = TAdapter::AddModImpl(v, mask % vector_spec.modulus(),
281                                    vector_spec.modulus());
282           i += inc;
283         }
284       }
285 
286       for (const auto& prng_key : prng_keys_to_subtract) {
287         if (async_abort && async_abort->Signalled()) return nullptr;
288         AesKey digest_key =
289             DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
290         PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
291                         bytes_per_output);
292         int i = 0;
293         while (i < vector_spec.length()) {
294           auto& v = mask_vector_buffer[i];
295           auto mask = prng.NextMask();
296           auto reject = mask < rejection_threshold;
297           auto inc = reject ? 0 : 1;
298           mask = reject ? 0 : mask;
299           v = TAdapter::SubtractModImpl(v, mask % vector_spec.modulus(),
300                                         vector_spec.modulus());
301           i += inc;
302         }
303       }
304     }
305 
306     if (async_abort && async_abort->Signalled()) return nullptr;
307     map_of_masks->emplace(vector_spec.name(),
308                           TVector(mask_vector_buffer, vector_spec.modulus()));
309   }
310   return map_of_masks;
311 }
312 
MapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)313 std::unique_ptr<SecAggVectorMap> MapOfMasks(
314     const std::vector<AesKey>& prng_keys_to_add,
315     const std::vector<AesKey>& prng_keys_to_subtract,
316     const std::vector<InputVectorSpecification>& input_vector_specs,
317     const SessionId& session_id, const AesPrngFactory& prng_factory,
318     AsyncAbort* async_abort) {
319   return MapOfMasksImpl<AddModAdapter, SecAggVector, SecAggVectorMap>(
320       prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
321       prng_factory, async_abort);
322 }
323 
MapOfMasksV3(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)324 std::unique_ptr<SecAggVectorMap> MapOfMasksV3(
325     const std::vector<AesKey>& prng_keys_to_add,
326     const std::vector<AesKey>& prng_keys_to_subtract,
327     const std::vector<InputVectorSpecification>& input_vector_specs,
328     const SessionId& session_id, const AesPrngFactory& prng_factory,
329     AsyncAbort* async_abort) {
330   return MapOfMasksImpl<AddModOptAdapter, SecAggVector, SecAggVectorMap>(
331       prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
332       prng_factory, async_abort);
333 }
334 
// Returns a new SecAggVector whose i-th element is (a[i] + b[i]) % modulus.
// `a` and `b` must share the same modulus and element count (FCP_CHECKed).
// The result is encoded with a's bit width.
// NOTE(review): the plain sum assumes values small enough that a[i] + b[i]
// cannot wrap a uint64_t; confirm SecAggVector's modulus bound.
SecAggVector AddVectors(const SecAggVector& a, const SecAggVector& b) {
  FCP_CHECK(a.modulus() == b.modulus() && a.num_elements() == b.num_elements());
  const uint64_t modulus = a.modulus();
  SecAggVector::Decoder lhs_reader(a);
  SecAggVector::Decoder rhs_reader(b);
  SecAggVector::Coder sum_writer(modulus, static_cast<int>(a.bit_width()),
                                 a.num_elements());
  const int count = static_cast<int>(a.num_elements());
  for (int i = 0; i < count; ++i) {
    const auto lhs = lhs_reader.ReadValue();
    const auto rhs = rhs_reader.ReadValue();
    sum_writer.WriteValue((lhs + rhs) % modulus);
  }
  return std::move(sum_writer).Create();
}
349 
AddMaps(const SecAggVectorMap & a,const SecAggVectorMap & b)350 std::unique_ptr<SecAggVectorMap> AddMaps(const SecAggVectorMap& a,
351                                          const SecAggVectorMap& b) {
352   auto result = std::make_unique<SecAggVectorMap>();
353   for (const auto& item : a) {
354     result->emplace(item.first, AddVectors(item.second, b.at(item.first)));
355   }
356   return result;
357 }
358 
UnpackedMapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)359 std::unique_ptr<SecAggUnpackedVectorMap> UnpackedMapOfMasks(
360     const std::vector<AesKey>& prng_keys_to_add,
361     const std::vector<AesKey>& prng_keys_to_subtract,
362     const std::vector<InputVectorSpecification>& input_vector_specs,
363     const SessionId& session_id, const AesPrngFactory& prng_factory,
364     AsyncAbort* async_abort) {
365   return MapOfMasksImpl<AddModOptAdapter, SecAggUnpackedVector,
366                         SecAggUnpackedVectorMap>(
367       prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
368       prng_factory, async_abort);
369 }
370 
371 }  // namespace secagg
372 }  // namespace fcp
373