1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/shared/map_of_masks.h"
18
19 #include <algorithm>
20 #include <atomic>
21 #include <cstdint>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/numeric/bits.h"
28 #include "absl/strings/str_cat.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/secagg/shared/aes_key.h"
31 #include "fcp/secagg/shared/compute_session_id.h"
32 #include "fcp/secagg/shared/input_vector_specification.h"
33 #include "fcp/secagg/shared/math.h"
34 #include "fcp/secagg/shared/prng.h"
35 #include "fcp/secagg/shared/secagg_vector.h"
36 #include "openssl/evp.h"
37
38 namespace fcp {
39 namespace secagg {
40
// Constant for backwards compatibility with legacy clients. Even though it is
// no longer needed, removing it would be disruptive due to making a large
// number of clients incompatible while not providing any benefits.
// It is hashed as a single byte into every PRNG key derived by DigestKey.
uint8_t kPrngSeedConstant = 0x02;

// We specifically avoid sample_bits == 64 to sidestep numerical precision
// issues, e.g. a uint64_t cannot represent the associated modulus.
constexpr int kMaxSampleBits = 63;

// We consider using at most 16 additional random bits from the underlying
// PRNG per sample.
constexpr int kMaxSampleBitsExpansion = 16;
54
DigestKey(EVP_MD_CTX * mdctx,const std::string & prng_input,int bit_width,const AesKey & prng_key)55 static AesKey DigestKey(EVP_MD_CTX* mdctx, const std::string& prng_input,
56 int bit_width, const AesKey& prng_key) {
57 int input_size = prng_input.size();
58 std::string input_size_data = IntToByteString(input_size);
59 std::string bit_width_data = IntToByteString(bit_width);
60 FCP_CHECK(EVP_DigestInit_ex(mdctx, EVP_sha256(), nullptr));
61 FCP_CHECK(EVP_DigestUpdate(mdctx, bit_width_data.c_str(), sizeof(int)));
62 FCP_CHECK(EVP_DigestUpdate(mdctx, prng_key.data(), prng_key.size()));
63 FCP_CHECK(EVP_DigestUpdate(mdctx, &kPrngSeedConstant, 1));
64 FCP_CHECK(EVP_DigestUpdate(mdctx, input_size_data.c_str(), sizeof(int)));
65 FCP_CHECK(EVP_DigestUpdate(mdctx, prng_input.c_str(), input_size));
66
67 uint8_t digest[AesKey::kSize];
68 uint32_t digest_length = 0;
69 FCP_CHECK(EVP_DigestFinal_ex(mdctx, digest, &digest_length));
70 FCP_CHECK(digest_length == AesKey::kSize);
71 return AesKey(digest);
72 }
73
74 // Determines whether sample_bits_1 or sample_bits_2 will be more efficient
75 // for sampling uniformly from [0, modulus).
76 //
choose_better_sample_bits(uint64_t modulus,int sample_bits_1,int sample_bits_2)77 int choose_better_sample_bits(uint64_t modulus, int sample_bits_1,
78 int sample_bits_2) {
79 FCP_CHECK(sample_bits_1 <= sample_bits_2);
80 FCP_CHECK(sample_bits_2 <= kMaxSampleBits);
81 FCP_CHECK(sample_bits_2 - sample_bits_1 <= kMaxSampleBitsExpansion);
82
83 uint64_t sample_modulus_1 = 1ULL << sample_bits_1;
84 FCP_CHECK(modulus <= sample_modulus_1);
85
86 if (sample_bits_1 == sample_bits_2) {
87 return sample_bits_1;
88 }
89
90 uint64_t sample_modulus_2 = 1ULL << sample_bits_2;
91 uint64_t sample_modulus_2_over_1 = 1ULL << (sample_bits_2 - sample_bits_1);
92 uint32_t cost_per_sample_1 = DivideRoundUp(sample_bits_1, 8);
93 uint32_t cost_per_sample_2 = DivideRoundUp(sample_bits_2, 8);
94 uint64_t modulus_reps_1 = sample_modulus_1 / modulus;
95 uint64_t modulus_reps_2 = sample_modulus_2 / modulus;
96 uint64_t cost_product_1 = cost_per_sample_1 * modulus_reps_1;
97 uint64_t cost_product_2 =
98 cost_per_sample_2 * modulus_reps_2 * sample_modulus_2_over_1;
99 return cost_product_1 > cost_product_2 ? sample_bits_2 : sample_bits_1;
100 }
101
102 // Computes the sample_bits that minimizes the expected number of bytes of
103 // randomness that will be consumed when drawing a uniform sample from
104 // [0, modulus) using our rejection sampling algorithm.
105 //
compute_best_sample_bits(uint64_t modulus)106 int compute_best_sample_bits(uint64_t modulus) {
107 int min_sample_bits = static_cast<int>(absl::bit_width(modulus - 1ULL));
108 int max_sample_bits = std::min(kMaxSampleBitsExpansion,
109 min_sample_bits + kMaxSampleBitsExpansion);
110 int best_sample_bits = min_sample_bits;
111 for (int sample_bits = min_sample_bits + 1; sample_bits <= max_sample_bits;
112 sample_bits++) {
113 best_sample_bits =
114 choose_better_sample_bits(modulus, best_sample_bits, sample_bits);
115 }
116 return best_sample_bits;
117 }
118
119 // PrngBuffer implements the logic for generating pseudo-random masks while
120 // fetching and caching buffers of psedo-random uint8_t numbers.
121 // Two important factors of this implementation compared to using SecurePrng
122 // directly are:
123 // 1) The implementation is fully inlineable allowing the the compiler to
124 // greatly optimize the resulting code.
125 // 2) Checking whether a new buffer of pseudo-random bytes needs to be filled is
126 // done only once per mask as opposed to doing that for every byte, which
127 // optimizes the most nested loop.
128 class PrngBuffer {
129 public:
PrngBuffer(std::unique_ptr<SecurePrng> prng,uint8_t msb_mask,size_t bytes_per_output)130 PrngBuffer(std::unique_ptr<SecurePrng> prng, uint8_t msb_mask,
131 size_t bytes_per_output)
132 : prng_(static_cast<SecureBatchPrng*>(prng.release())),
133 msb_mask_(msb_mask),
134 bytes_per_output_(bytes_per_output),
135 buffer_(prng_->GetMaxBufferSize()),
136 buffer_end_(buffer_.data() + buffer_.size()) {
137 FCP_CHECK((prng_->GetMaxBufferSize() % bytes_per_output) == 0)
138 << "PRNG buffer size must be a multiple bytes_per_output.";
139 FillBuffer();
140 }
141
NextMask()142 inline uint64_t NextMask() {
143 if (buffer_ptr_ == buffer_end_) {
144 FillBuffer();
145 }
146
147 auto output = static_cast<uint64_t>((*buffer_ptr_++) & msb_mask_);
148 for (size_t i = 1; i < bytes_per_output_; ++i) {
149 output <<= 8UL;
150 output |= static_cast<uint64_t>(*buffer_ptr_++);
151 }
152 return output;
153 }
154
155 private:
buffer_size()156 inline int buffer_size() { return static_cast<int>(buffer_.size()); }
157
FillBuffer()158 inline void FillBuffer() {
159 buffer_ptr_ = buffer_.data();
160 FCP_CHECK(prng_->RandBuffer(buffer_.data(), buffer_size()) ==
161 buffer_size());
162 }
163
164 std::unique_ptr<SecureBatchPrng> prng_;
165 const uint8_t msb_mask_;
166 const size_t bytes_per_output_;
167 std::vector<uint8_t> buffer_;
168 const uint8_t* buffer_ptr_ = nullptr;
169 const uint8_t* const buffer_end_;
170 };
171
172 struct AddModAdapter {
AddModImplfcp::secagg::AddModAdapter173 inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
174 return AddMod(a, b, z);
175 }
SubtractModImplfcp::secagg::AddModAdapter176 inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
177 return SubtractMod(a, b, z);
178 }
179 };
180
181 struct AddModOptAdapter {
AddModImplfcp::secagg::AddModOptAdapter182 inline static uint64_t AddModImpl(uint64_t a, uint64_t b, uint64_t z) {
183 return AddModOpt(a, b, z);
184 }
SubtractModImplfcp::secagg::AddModOptAdapter185 inline static uint64_t SubtractModImpl(uint64_t a, uint64_t b, uint64_t z) {
186 return SubtractModOpt(a, b, z);
187 }
188 };
189
// Templated implementation of MapOfMasks that allows substituting
// AddMod and SubtractMod implementations.
//
// For every entry of input_vector_specs this builds a zero-initialized vector
// of vector_spec.length() elements, adds one pseudo-random mask per key in
// prng_keys_to_add, subtracts one per key in prng_keys_to_subtract (all
// arithmetic is done through TAdapter with vector_spec.modulus()), and stores
// the result in the returned map under vector_spec.name().
//
// Template parameters:
//   TAdapter   - provides static AddModImpl / SubtractModImpl.
//   TVector    - constructed from (std::vector<uint64_t>, modulus).
//   TVectorMap - map type; entries are added with emplace(name, TVector).
//
// prng_factory must support batch mode (FCP_CHECK). async_abort may be null;
// when it is non-null and signalled, the function stops at the next check
// point and returns nullptr.
template <typename TAdapter, typename TVector, typename TVectorMap>
inline std::unique_ptr<TVectorMap> MapOfMasksImpl(
    const std::vector<AesKey>& prng_keys_to_add,
    const std::vector<AesKey>& prng_keys_to_subtract,
    const std::vector<InputVectorSpecification>& input_vector_specs,
    const SessionId& session_id, const AesPrngFactory& prng_factory,
    AsyncAbort* async_abort) {
  FCP_CHECK(prng_factory.SupportsBatchMode());

  auto map_of_masks = std::make_unique<TVectorMap>();
  // A single digest context is shared by all DigestKey calls below; DigestKey
  // re-initializes it on every call.
  std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX*)> mdctx(EVP_MD_CTX_create(),
                                                           EVP_MD_CTX_destroy);
  FCP_CHECK(mdctx.get());
  for (const InputVectorSpecification& vector_spec : input_vector_specs) {
    if (async_abort && async_abort->Signalled()) return nullptr;
    // Number of bits needed to represent the largest value, modulus - 1.
    int bit_width =
        static_cast<int>(absl::bit_width(vector_spec.modulus() - 1ULL));
    // Per-vector PRNG input: session id, bit width, length and name.
    std::string prng_input =
        absl::StrCat(session_id.data, IntToByteString(bit_width),
                     IntToByteString(vector_spec.length()), vector_spec.name());
    std::vector<uint64_t> mask_vector_buffer(vector_spec.length(), 0);

    bool modulus_is_power_of_two = (1ULL << bit_width == vector_spec.modulus());
    if (modulus_is_power_of_two) {
      // Because the modulus is a power of two, we can sample uniformly
      // simply by drawing the correct number of random bits.
      int bytes_per_output = DivideRoundUp(bit_width, 8);
      // msb = "most significant byte"
      size_t bits_in_msb = bit_width - ((bytes_per_output - 1) * 8);
      // Keeps only the low bits_in_msb bits of the most significant byte.
      uint8_t msb_mask = (1UL << bits_in_msb) - 1;

      for (const auto& prng_key : prng_keys_to_add) {
        if (async_abort && async_abort->Signalled()) return nullptr;
        AesKey digest_key =
            DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
        PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
                        bytes_per_output);
        for (auto& v : mask_vector_buffer) {
          v = TAdapter::AddModImpl(v, prng.NextMask(), vector_spec.modulus());
        }
      }

      for (const auto& prng_key : prng_keys_to_subtract) {
        if (async_abort && async_abort->Signalled()) return nullptr;
        AesKey digest_key =
            DigestKey(mdctx.get(), prng_input, bit_width, prng_key);
        PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
                        bytes_per_output);
        for (auto& v : mask_vector_buffer) {
          v = TAdapter::SubtractModImpl(v, prng.NextMask(),
                                        vector_spec.modulus());
        }
      }
    } else {
      // Rejection Sampling algorithm for arbitrary moduli.
      // Follows Algorithm 3 from:
      // "Fast Random Integer Generation in an Interval," Daniel Lemire, 2018.
      // https://arxiv.org/pdf/1805.10941.pdf.
      //
      // The inner loops are structured to avoid conditional branches
      // and the associated branch misprediction errors they would entail.
      //
      // We choose sample_bits to minimize the expected number of bytes
      // drawn from the PRNG.

      int sample_bits = compute_best_sample_bits(vector_spec.modulus());
      int bytes_per_output = DivideRoundUp(sample_bits, 8);
      // msb = "most significant byte"
      size_t bits_in_msb = sample_bits - ((bytes_per_output - 1) * 8);
      // Keeps only the low bits_in_msb bits of the most significant byte.
      uint8_t msb_mask = (1UL << bits_in_msb) - 1;

      uint64_t sample_modulus = 1ULL << sample_bits;
      // Equals sample_modulus % modulus. Draws below this value are rejected,
      // leaving an accepted range whose size is a multiple of the modulus.
      uint64_t rejection_threshold =
          (sample_modulus - vector_spec.modulus()) % vector_spec.modulus();

      for (const auto& prng_key : prng_keys_to_add) {
        if (async_abort && async_abort->Signalled()) return nullptr;
        // Note: the key is derived from sample_bits here, not bit_width.
        AesKey digest_key =
            DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
        PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
                        bytes_per_output);
        int i = 0;
        while (i < vector_spec.length()) {
          auto& v = mask_vector_buffer[i];
          auto mask = prng.NextMask();
          auto reject = mask < rejection_threshold;
          // A rejected draw contributes 0 to v and does not advance i, so the
          // same element is retried with the next draw.
          auto inc = reject ? 0 : 1;
          mask = reject ? 0 : mask;
          v = TAdapter::AddModImpl(v, mask % vector_spec.modulus(),
                                   vector_spec.modulus());
          i += inc;
        }
      }

      for (const auto& prng_key : prng_keys_to_subtract) {
        if (async_abort && async_abort->Signalled()) return nullptr;
        // Note: the key is derived from sample_bits here, not bit_width.
        AesKey digest_key =
            DigestKey(mdctx.get(), prng_input, sample_bits, prng_key);
        PrngBuffer prng(prng_factory.MakePrng(digest_key), msb_mask,
                        bytes_per_output);
        int i = 0;
        while (i < vector_spec.length()) {
          auto& v = mask_vector_buffer[i];
          auto mask = prng.NextMask();
          auto reject = mask < rejection_threshold;
          // A rejected draw subtracts 0 from v and does not advance i, so the
          // same element is retried with the next draw.
          auto inc = reject ? 0 : 1;
          mask = reject ? 0 : mask;
          v = TAdapter::SubtractModImpl(v, mask % vector_spec.modulus(),
                                        vector_spec.modulus());
          i += inc;
        }
      }
    }

    if (async_abort && async_abort->Signalled()) return nullptr;
    map_of_masks->emplace(vector_spec.name(),
                          TVector(mask_vector_buffer, vector_spec.modulus()));
  }
  return map_of_masks;
}
312
MapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)313 std::unique_ptr<SecAggVectorMap> MapOfMasks(
314 const std::vector<AesKey>& prng_keys_to_add,
315 const std::vector<AesKey>& prng_keys_to_subtract,
316 const std::vector<InputVectorSpecification>& input_vector_specs,
317 const SessionId& session_id, const AesPrngFactory& prng_factory,
318 AsyncAbort* async_abort) {
319 return MapOfMasksImpl<AddModAdapter, SecAggVector, SecAggVectorMap>(
320 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
321 prng_factory, async_abort);
322 }
323
MapOfMasksV3(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)324 std::unique_ptr<SecAggVectorMap> MapOfMasksV3(
325 const std::vector<AesKey>& prng_keys_to_add,
326 const std::vector<AesKey>& prng_keys_to_subtract,
327 const std::vector<InputVectorSpecification>& input_vector_specs,
328 const SessionId& session_id, const AesPrngFactory& prng_factory,
329 AsyncAbort* async_abort) {
330 return MapOfMasksImpl<AddModOptAdapter, SecAggVector, SecAggVectorMap>(
331 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
332 prng_factory, async_abort);
333 }
334
// Returns the element-wise sum of `a` and `b`, reduced modulo their shared
// modulus. Both vectors must have the same modulus and the same number of
// elements (FCP_CHECK). The result is encoded with a's bit width.
SecAggVector AddVectors(const SecAggVector& a, const SecAggVector& b) {
  FCP_CHECK(a.modulus() == b.modulus() && a.num_elements() == b.num_elements());
  const uint64_t modulus = a.modulus();
  SecAggVector::Decoder decoder_a(a);
  SecAggVector::Decoder decoder_b(b);
  SecAggVector::Coder sum_coder(modulus, static_cast<int>(a.bit_width()),
                                a.num_elements());
  // Both inputs are decoded in lockstep, one element per iteration.
  const int element_count = static_cast<int>(a.num_elements());
  for (int index = 0; index < element_count; ++index) {
    const auto lhs = decoder_a.ReadValue();
    const auto rhs = decoder_b.ReadValue();
    sum_coder.WriteValue((lhs + rhs) % modulus);
  }
  return std::move(sum_coder).Create();
}
349
AddMaps(const SecAggVectorMap & a,const SecAggVectorMap & b)350 std::unique_ptr<SecAggVectorMap> AddMaps(const SecAggVectorMap& a,
351 const SecAggVectorMap& b) {
352 auto result = std::make_unique<SecAggVectorMap>();
353 for (const auto& item : a) {
354 result->emplace(item.first, AddVectors(item.second, b.at(item.first)));
355 }
356 return result;
357 }
358
UnpackedMapOfMasks(const std::vector<AesKey> & prng_keys_to_add,const std::vector<AesKey> & prng_keys_to_subtract,const std::vector<InputVectorSpecification> & input_vector_specs,const SessionId & session_id,const AesPrngFactory & prng_factory,AsyncAbort * async_abort)359 std::unique_ptr<SecAggUnpackedVectorMap> UnpackedMapOfMasks(
360 const std::vector<AesKey>& prng_keys_to_add,
361 const std::vector<AesKey>& prng_keys_to_subtract,
362 const std::vector<InputVectorSpecification>& input_vector_specs,
363 const SessionId& session_id, const AesPrngFactory& prng_factory,
364 AsyncAbort* async_abort) {
365 return MapOfMasksImpl<AddModOptAdapter, SecAggUnpackedVector,
366 SecAggUnpackedVectorMap>(
367 prng_keys_to_add, prng_keys_to_subtract, input_vector_specs, session_id,
368 prng_factory, async_abort);
369 }
370
371 } // namespace secagg
372 } // namespace fcp
373