/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "fcp/secagg/shared/secagg_vector.h"

#include <inttypes.h>

#include <algorithm>
#include <array>
#include <climits>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "fcp/base/monitoring.h"
#include "fcp/secagg/shared/math.h"

namespace fcp {
namespace secagg {

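// Out-of-line definition for the constant declared in secagg_vector.h; this is
// needed so that ODR-uses of kMaxModulus link correctly in pre-C++17 builds.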
const uint64_t SecAggVector::kMaxModulus;

SecAggVector::SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus,
                           bool branchless_codec)
    : modulus_(modulus),
      bit_width_(SecAggVector::GetBitWidth(modulus)),
      num_elements_(span.size()),
      branchless_codec_(branchless_codec) {
  FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
      << "The specified modulus is not valid: must be > 1 and <= "
      << kMaxModulus << "; supplied value: " << modulus_;
  // Ensure every element of the supplied vector is within the modulus.
  // (Elements are unsigned, so only the upper bound needs checking.)
  for (uint64_t element : span) {
    FCP_CHECK(element < modulus_)
        << "The span does not have the appropriate modulus: element "
           "with value "
        << element << " found, max value allowed " << (modulus_ - 1ULL);
  }

  // Pack the vector into a byte string, initialized to all zero bytes.
  if (branchless_codec_) {
    PackUint64IntoByteStringBranchless(span);
  } else {
    int num_bytes_needed =
        DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
    packed_bytes_ = std::string(num_bytes_needed, '\0');
    for (int i = 0; static_cast<size_t>(i) < span.size(); ++i) {
      PackUint64IntoByteStringAt(i, span[i]);
    }
  }
}
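
// Illustrative round trip (a sketch, not exercised in this file): with
// modulus 32, each element packs into 5 bits (the width needed for values
// below 32), and GetAsUint64Vector() recovers the original values:
//
//   SecAggVector vec({1, 3, 7}, /*modulus=*/32, /*branchless_codec=*/true);
//   std::vector<uint64_t> values = vec.GetAsUint64Vector();  // {1, 3, 7}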

SecAggVector::SecAggVector(std::string packed_bytes, uint64_t modulus,
                           size_t num_elements, bool branchless_codec)
    : packed_bytes_(std::move(packed_bytes)),
      modulus_(modulus),
      bit_width_(SecAggVector::GetBitWidth(modulus)),
      num_elements_(num_elements),
      branchless_codec_(branchless_codec) {
  FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
      << "The specified modulus is not valid: must be > 1 and <= "
      << kMaxModulus << "; supplied value: " << modulus_;
  int expected_num_bytes =
      DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
  FCP_CHECK(packed_bytes_.size() == static_cast<size_t>(expected_num_bytes))
      << "The supplied string is not the right size for " << num_elements_
      << " packed elements: the given string holds " << packed_bytes_.size()
      << " bytes, but " << expected_num_bytes << " bytes are needed.";
}

std::vector<uint64_t> SecAggVector::GetAsUint64Vector() const {
  CheckHasValue();
  std::vector<uint64_t> long_vector;
  if (branchless_codec_) {
    UnpackByteStringToUint64VectorBranchless(&long_vector);
  } else {
    long_vector.reserve(num_elements_);
    for (int i = 0; static_cast<size_t>(i) < num_elements_; ++i) {
      long_vector.push_back(
          UnpackUint64FromByteStringAt(i, bit_width_, packed_bytes_));
    }
  }
  return long_vector;
}

void SecAggVector::PackUint64IntoByteStringAt(int index, uint64_t element) {
  // The element is packed starting with its least significant bits, which
  // occupy the lowest (leftmost) bit positions in the byte string.
  //
  // TODO(team): Optimize out this extra per-element computation.
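  // Worked example: with bit_width_ = 3, packing the elements {5, 1} yields
  // byte 0 = 0b00001101, i.e. element 0 (0b101) in bits 0-2 and element 1
  // (0b001) in bits 3-5.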
  int leftmost_bit_position = index * bit_width_;
  int current_byte_index = leftmost_bit_position / 8;
  int bits_left_to_pack = bit_width_;

  // If leftmost_bit_position is in the middle of a byte, first fill that byte.
  if (leftmost_bit_position % 8 != 0) {
    int starting_bit_position = leftmost_bit_position % 8;
    int empty_bits_left = 8 - starting_bit_position;
    // Extract enough bits from "element" to fill the current byte, and shift
    // them to the correct position.
    uint64_t mask =
        (1ULL << std::min(empty_bits_left, bits_left_to_pack)) - 1ULL;
    uint64_t value_to_add = (element & mask) << starting_bit_position;
    packed_bytes_[current_byte_index] |= static_cast<char>(value_to_add);

    bits_left_to_pack -= empty_bits_left;
    element >>= empty_bits_left;
    current_byte_index++;
  }

  // The current bit position is now aligned with the start of the current
  // byte. Pack as many whole bytes as possible.
  uint64_t lower_eight_bit_mask = 0xffULL;
  while (bits_left_to_pack >= 8) {
    packed_bytes_[current_byte_index] =
        static_cast<char>(element & lower_eight_bit_mask);

    bits_left_to_pack -= 8;
    element >>= 8;
    current_byte_index++;
  }

  // Pack the remaining partial byte, if necessary.
  if (bits_left_to_pack > 0) {
    // There are fewer than 8 bits left, so pack all remaining bits at once.
    packed_bytes_[current_byte_index] |= static_cast<char>(element);
  }
}

uint64_t SecAggVector::UnpackUint64FromByteStringAt(
    int index, int bit_width, const std::string& byte_string) {
  // All the bits starting from, and including, this bit are copied.
  int leftmost_bit_position = index * bit_width;
  // The byte containing the lowest-order bit to be copied.
  int leftmost_byte_index = leftmost_bit_position / 8;
  // All bits up to, but not including, this bit are copied.
  int right_boundary_bit_position = ((index + 1) * bit_width);
  // The byte containing the highest-order bit to copy.
  int rightmost_byte_index = (right_boundary_bit_position - 1) / 8;

  // Special case: when the entire value to unpack is contained in a single
  // byte, extract it in a single step.
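  // (With bit_width = 3, for instance, elements 0 and 1 both fall entirely
  // within byte 0.)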
  if (leftmost_byte_index == rightmost_byte_index) {
    int num_bits_to_skip = (leftmost_bit_position % 8);
    int mask = ((1 << bit_width) - 1) << num_bits_to_skip;
    // Drop the extraneous bits below and above the value to unpack.
    uint64_t unpacked_element =
        (byte_string[leftmost_byte_index] & mask) >> num_bits_to_skip;
    return unpacked_element;
  }

  // Normal case: the value to unpack spans one or more byte boundaries.
  // The element is unpacked in reverse order, starting from its most
  // significant bits in the rightmost byte.
  int current_byte_index = rightmost_byte_index;
  uint64_t unpacked_element = 0;
  int bits_left_to_unpack = bit_width;

  // If right_boundary_bit_position is in the middle of a byte, unpack the bits
  // up to right_boundary_bit_position within that byte.
  if (right_boundary_bit_position % 8 != 0) {
    int bits_to_copy_from_current_byte = (right_boundary_bit_position % 8);
    int lower_bits_mask = (1 << bits_to_copy_from_current_byte) - 1;
    unpacked_element |= (byte_string[current_byte_index] & lower_bits_mask);

    bits_left_to_unpack -= bits_to_copy_from_current_byte;
    current_byte_index--;
  }

  // The current bit position is now aligned with a byte boundary. Unpack as
  // many whole bytes as possible.
  while (bits_left_to_unpack >= 8) {
    unpacked_element <<= 8;
    unpacked_element |= byte_string[current_byte_index] & 0xff;

    bits_left_to_unpack -= 8;
    current_byte_index--;
  }

  // Unpack the remaining partial byte, if necessary.
  if (bits_left_to_unpack > 0) {
    unpacked_element <<= bits_left_to_unpack;
    int bits_to_skip_in_current_byte = 8 - bits_left_to_unpack;
    unpacked_element |= (byte_string[current_byte_index] & 0xff) >>
                        bits_to_skip_in_current_byte;
  }

  return unpacked_element;
}

void SecAggVector::PackUint64IntoByteStringBranchless(
    const absl::Span<const uint64_t> span) {
  SecAggVector::Coder coder(modulus_, bit_width_, num_elements_);
  for (uint64_t element : span) {
    coder.WriteValue(element);
  }
  packed_bytes_ = std::move(coder).Create().TakePackedBytes();
}

void SecAggVector::UnpackByteStringToUint64VectorBranchless(
    std::vector<uint64_t>* long_vector) const {
  long_vector->resize(num_elements_);
  Decoder decoder(*this);
  for (uint64_t& element : *long_vector) {
    element = decoder.ReadValue();
  }
}

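// The Decoder streams values back out of the packed representation. It keeps
// up to 64 bits of pending input in scratch_, with read_cursor_bit_ tracking
// the bit offset at which the next chunk from the byte stream is spliced in;
// each ReadValue() call consumes bit_width_ bits and then tops scratch_ back
// up via ReadData().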
SecAggVector::Decoder::Decoder(absl::string_view packed_bytes,
                               uint64_t modulus)
    : read_cursor_(packed_bytes.data()),
      cursor_sentinel_(packed_bytes.data() + packed_bytes.size()),
      cursor_read_value_(0),
      scratch_(0),
      read_cursor_bit_(0),
      bit_width_(SecAggVector::GetBitWidth(modulus)),
      mask_((1ULL << bit_width_) - 1),
      modulus_(modulus) {
  ReadData();
}

inline void SecAggVector::Decoder::ReadData() {
  static constexpr ssize_t kBlockSizeBytes = sizeof(cursor_read_value_);
  const ptrdiff_t bytes_remaining = cursor_sentinel_ - read_cursor_;
  // Here, we use memcpy() to avoid the undefined behavior of
  // reinterpret_cast<> on unaligned reads, opportunistically reading up to
  // eight bytes at a time.
  if (bytes_remaining >= kBlockSizeBytes) {
    memcpy(&cursor_read_value_, read_cursor_, kBlockSizeBytes);
  } else {
    memcpy(&cursor_read_value_, read_cursor_,
           bytes_remaining > 0 ? bytes_remaining : 0);
  }
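  // Splice the freshly read bits in above the bits already pending in
  // scratch_.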
  scratch_ |= cursor_read_value_ << static_cast<unsigned>(read_cursor_bit_);
}

uint64_t SecAggVector::Decoder::ReadValue() {
  static constexpr int kBlockSizeBits = sizeof(cursor_read_value_) * 8;
  // Get the current value.
  const uint64_t current_value = scratch_ & mask_;
  // Advance to the next value.
  scratch_ >>= bit_width_;
  int unwritten_bits = read_cursor_bit_;
  read_cursor_bit_ -= bit_width_;
  // Because we read in eight-byte chunks on byte boundaries, and only keep
  // eight bytes of scratch, a portion of the read may not have fit, and now
  // belongs at the back of scratch. The following assignments compile to a
  // branchless conditional move on Clang x86_64 and Clang ARMv{7, 8}.
  int read_bit_shift = bit_width_ - unwritten_bits;
  unsigned int right_shift_value = read_bit_shift > 0 ? read_bit_shift : 0;
  unsigned int left_shift_value = read_bit_shift < 0 ? -read_bit_shift : 0;
  cursor_read_value_ >>= right_shift_value;
  cursor_read_value_ <<= left_shift_value;
  scratch_ |= cursor_read_value_;
  int valid_scratch_bits = kBlockSizeBits - bit_width_ + unwritten_bits;
  valid_scratch_bits = (valid_scratch_bits > kBlockSizeBits)
                           ? kBlockSizeBits
                           : valid_scratch_bits;
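  // Advance the read cursor by the largest whole number of bytes that now fit
  // in scratch_, capping at kBlockSizeBits - 8 so that the shift in ReadData()
  // stays below 64 (re-reading the top byte is harmless, since OR-ing
  // identical bits is idempotent).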
  int new_read_cursor_bit =
      read_cursor_bit_ +
      static_cast<signed>(
          (static_cast<unsigned>(valid_scratch_bits - read_cursor_bit_) &
           ~7U));
  new_read_cursor_bit = new_read_cursor_bit == kBlockSizeBits
                            ? static_cast<int>(kBlockSizeBits - 8)
                            : new_read_cursor_bit;
  read_cursor_ +=
      static_cast<unsigned>((new_read_cursor_bit - read_cursor_bit_)) / 8;
  read_cursor_bit_ = new_read_cursor_bit;
  ReadData();
  // current_value is guaranteed to be in the [0, 2 * modulus_) range due to
  // the relationship between modulus_ and bit_width_, so the statement below
  // guarantees that the return value is in the [0, modulus_) range.
  return current_value < modulus_ ? current_value : current_value - modulus_;
}

SecAggVector::Coder::Coder(uint64_t modulus, int bit_width,
                           size_t num_elements)
    : modulus_(modulus),
      bit_width_(bit_width),
      num_elements_(num_elements),
      target_cursor_value_(0),
      starting_bit_position_(0) {
  num_bytes_needed_ =
      DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
  // The branchless variant assumes eight bytes of scratch space beyond the
  // packed data; the string is resized to the correct size at the end.
  packed_bytes_ = std::string(num_bytes_needed_ + 8, '\0');
  write_cursor_ = &packed_bytes_[0];
}

void SecAggVector::Coder::WriteValue(uint64_t value) {
  static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
  // Here, we use memcpy() to avoid the undefined behavior of
  // reinterpret_cast<> on unaligned stores, opportunistically writing eight
  // bytes at a time.
  target_cursor_value_ &= (1ULL << starting_bit_position_) - 1;
  target_cursor_value_ |= value << starting_bit_position_;
  std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
  const auto new_write_cursor =
      write_cursor_ + (starting_bit_position_ + bit_width_) / 8;
  const auto new_starting_bit_position =
      (starting_bit_position_ + bit_width_) % 8;
  // Because we write in eight-byte chunks, a portion of the element may have
  // been missed, and now belongs at the front of target_cursor_value_. The
  // following assignments compile to a branchless conditional move on
  // Clang x86_64 and Clang ARMv{7, 8}.
  auto runt_cursor_value =
      new_starting_bit_position
          ? value >> (static_cast<unsigned>(kBlockSize * 8 -
                                            starting_bit_position_) &
                      // Mask keeps the unused shift in range (avoids UB).
                      (kBlockSize * 8 - 1))
          : 0;
  // Otherwise, remove fully written values from our scratch space.
  target_cursor_value_ >>=
      (static_cast<unsigned>(new_write_cursor - write_cursor_) * 8) &
      (kBlockSize * 8 - 1);  // Mask keeps the unused shift in range.
  target_cursor_value_ = (new_write_cursor - write_cursor_ == kBlockSize)
                             ? runt_cursor_value
                             : target_cursor_value_;
  write_cursor_ = new_write_cursor;
  starting_bit_position_ = new_starting_bit_position;
}

SecAggVector SecAggVector::Coder::Create() && {
  static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
  std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
  packed_bytes_.resize(num_bytes_needed_);
  return SecAggVector(std::move(packed_bytes_), modulus_, num_elements_,
                      /*branchless_codec=*/true);
}
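
// Illustrative use of the Coder (a sketch mirroring
// PackUint64IntoByteStringBranchless() above, not exercised in this file):
//
//   SecAggVector::Coder coder(/*modulus=*/32, /*bit_width=*/5,
//                             /*num_elements=*/3);
//   for (uint64_t v : {1, 3, 7}) coder.WriteValue(v);
//   SecAggVector vec = std::move(coder).Create();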

void SecAggUnpackedVector::Add(const SecAggVector& other) {
  FCP_CHECK(num_elements() == other.num_elements());
  FCP_CHECK(modulus() == other.modulus());
  SecAggVector::Decoder decoder(other);
  for (auto& v : *this) {
    v = AddModOpt(v, decoder.ReadValue(), modulus());
  }
}

void SecAggUnpackedVectorMap::Add(const SecAggVectorMap& other) {
  FCP_CHECK(size() == other.size());
  for (auto& [name, vector] : *this) {
    auto it = other.find(name);
    FCP_CHECK(it != other.end());
    vector.Add(it->second);
  }
}

std::unique_ptr<SecAggUnpackedVectorMap> SecAggUnpackedVectorMap::AddMaps(
    const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b) {
  auto result = std::make_unique<SecAggUnpackedVectorMap>();
  for (const auto& entry : a) {
    const auto& name = entry.first;
    auto length = entry.second.num_elements();
    auto modulus = entry.second.modulus();
    const auto& a_at_name = entry.second;
    const auto& b_at_name = b.at(name);
    SecAggUnpackedVector result_vector(length, modulus);
    for (size_t j = 0; j < length; ++j) {
      result_vector[j] = AddModOpt(a_at_name[j], b_at_name[j], modulus);
    }
    result->emplace(name, std::move(result_vector));
  }
  return result;
}

}  // namespace secagg
}  // namespace fcp