xref: /aosp_15_r20/external/federated-compute/fcp/secagg/shared/secagg_vector.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/shared/secagg_vector.h"
18 
19 #include <inttypes.h>
20 
21 #include <algorithm>
22 #include <array>
23 #include <climits>
24 #include <cstdint>
25 #include <cstring>
26 #include <memory>
27 #include <string>
28 #include <utility>
29 #include <vector>
30 
31 #include "absl/strings/string_view.h"
32 #include "absl/types/span.h"
33 #include "fcp/base/monitoring.h"
34 #include "fcp/secagg/shared/math.h"
35 
36 namespace fcp {
37 namespace secagg {
38 
// Out-of-line definition for the static constexpr class member; needed so
// ODR-uses (e.g. taking its address or binding it to a reference, as the
// FCP_CHECK streams below do) link under pre-C++17 rules.
const uint64_t SecAggVector::kMaxModulus;
40 
SecAggVector(absl::Span<const uint64_t> span,uint64_t modulus,bool branchless_codec)41 SecAggVector::SecAggVector(absl::Span<const uint64_t> span, uint64_t modulus,
42                            bool branchless_codec)
43     : modulus_(modulus),
44       bit_width_(SecAggVector::GetBitWidth(modulus)),
45       num_elements_(span.size()),
46       branchless_codec_(branchless_codec) {
47   FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
48       << "The specified modulus is not valid: must be > 1 and <= "
49       << kMaxModulus << "; supplied value : " << modulus_;
50   // Ensuring the supplied vector has the appropriate modulus.
51   for (uint64_t element : span) {
52     FCP_CHECK(element >= 0)
53         << "Only non negative elements are allowed in the vector.";
54     FCP_CHECK(element < modulus_)
55         << "The span does not have the appropriate modulus: element "
56            "with value "
57         << element << " found, max value allowed " << (modulus_ - 1ULL);
58   }
59 
60   // Packs the long vector into a string, initialized to all null.
61   if (branchless_codec_) {
62     PackUint64IntoByteStringBranchless(span);
63   } else {
64     int num_bytes_needed =
65         DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
66     packed_bytes_ = std::string(num_bytes_needed, '\0');
67     for (int i = 0; static_cast<size_t>(i) < span.size(); ++i) {
68       PackUint64IntoByteStringAt(i, span[i]);
69     }
70   }
71 }
72 
// Constructs a SecAggVector directly from an already-packed byte string,
// taking ownership of `packed_bytes`. Validates the modulus and that the
// string holds exactly ceil(num_elements * bit_width / 8) bytes.
SecAggVector::SecAggVector(std::string packed_bytes, uint64_t modulus,
                           size_t num_elements, bool branchless_codec)
    : packed_bytes_(std::move(packed_bytes)),
      modulus_(modulus),
      bit_width_(SecAggVector::GetBitWidth(modulus)),
      num_elements_(num_elements),
      branchless_codec_(branchless_codec) {
  // Reject moduli outside (1, kMaxModulus]; bit_width_ is only meaningful
  // within that range.
  FCP_CHECK(modulus_ > 1 && modulus_ <= kMaxModulus)
      << "The specified modulus is not valid: must be > 1 and <= "
      << kMaxModulus << "; supplied value : " << modulus_;
  // Each element occupies bit_width_ bits, packed back-to-back and rounded
  // up to a whole number of bytes overall.
  int expected_num_bytes = DivideRoundUp(num_elements_ * bit_width_, 8);
  FCP_CHECK(packed_bytes_.size() == static_cast<size_t>(expected_num_bytes))
      << "The supplied string is not the right size for " << num_elements_
      << " packed elements: given string has a limit of "
      << packed_bytes_.size() << " bytes, " << expected_num_bytes
      << " bytes would have been needed.";
}
90 
GetAsUint64Vector() const91 std::vector<uint64_t> SecAggVector::GetAsUint64Vector() const {
92   CheckHasValue();
93   std::vector<uint64_t> long_vector;
94   if (branchless_codec_) {
95     UnpackByteStringToUint64VectorBranchless(&long_vector);
96   } else {
97     long_vector.reserve(num_elements_);
98     for (int i = 0; i < num_elements_; ++i) {
99       long_vector.push_back(
100           UnpackUint64FromByteStringAt(i, bit_width_, packed_bytes_));
101     }
102   }
103   return long_vector;
104 }
105 
// Writes `element` into packed_bytes_ at the bit offset index * bit_width_.
// Elements are packed back-to-back with no padding, so a single element may
// straddle byte boundaries; bytes are assumed to be pre-zeroed (the |=
// writes only OR bits in).
void SecAggVector::PackUint64IntoByteStringAt(int index, uint64_t element) {
  // The element will be packed starting with the least significant (leftmost)
  // bits.
  //
  // TODO(team): Optimize out this extra per element computation.
  int leftmost_bit_position = index * bit_width_;
  int current_byte_index = leftmost_bit_position / 8;
  int bits_left_to_pack = bit_width_;

  // If leftmost_bit_position is in the middle of a byte, first fill that byte.
  if (leftmost_bit_position % 8 != 0) {
    int starting_bit_position = leftmost_bit_position % 8;
    int empty_bits_left = 8 - starting_bit_position;
    // Extract enough bits from "element" to fill the current byte, and shift
    // them to the correct position. The std::min handles elements narrower
    // than the remaining space in this byte.
    uint64_t mask = (1ULL << std::min(empty_bits_left, bits_left_to_pack)) - 1L;
    uint64_t value_to_add = (element & mask) << starting_bit_position;
    packed_bytes_[current_byte_index] |= static_cast<char>(value_to_add);

    // May go negative when the whole element fit inside this byte; both the
    // loop and the trailing `if` below then no-op, as intended.
    bits_left_to_pack -= empty_bits_left;
    element >>= empty_bits_left;
    current_byte_index++;
  }

  // Current bit position is now aligned with the start of the current byte.
  // Pack as many whole bytes as possible.
  uint64_t lower_eight_bit_mask = 255L;
  while (bits_left_to_pack >= 8) {
    packed_bytes_[current_byte_index] =
        static_cast<char>(element & lower_eight_bit_mask);

    bits_left_to_pack -= 8;
    element >>= 8;
    current_byte_index++;
  }

  // Pack the remaining partial byte, if necessary. The remaining bits of
  // `element` land in the low end of the next byte.
  if (bits_left_to_pack > 0) {
    // there should be < 8 bits left, so pack all remaining bits at once.
    packed_bytes_[current_byte_index] |= static_cast<char>(element);
  }
}
148 
// Extracts the `index`-th `bit_width`-bit element from `byte_string`.
// Inverse of PackUint64IntoByteStringAt: elements are packed back-to-back,
// least significant bits first.
uint64_t SecAggVector::UnpackUint64FromByteStringAt(
    int index, int bit_width, const std::string& byte_string) {
  // all the bits starting from, and including, this bit are copied.
  int leftmost_bit_position = index * bit_width;
  // byte containing the lowest order bit to be copied
  int leftmost_byte_index = leftmost_bit_position / 8;
  // all bits up to, but not including this bit, are copied.
  int right_boundary_bit_position = ((index + 1) * bit_width);
  // byte containing the highest order bit to copy
  int rightmost_byte_index = (right_boundary_bit_position - 1) / 8;

  // Special case: when the entire long value to unpack is contained in a single
  // byte, then extract that long value in a single step.
  if (leftmost_byte_index == rightmost_byte_index) {
    int num_bits_to_skip = (leftmost_bit_position % 8);
    int mask = ((1 << bit_width) - 1) << num_bits_to_skip;
    // drop the extraneous bits below and above the value to unpack. The mask
    // also strips any sign-extension bits from the (possibly signed) char.
    uint64_t unpacked_element =
        (byte_string[leftmost_byte_index] & mask) >> num_bits_to_skip;
    return unpacked_element;
  }

  // Normal case: the value to unpack spans one or more byte boundaries.
  // The element will be unpacked in reverse order, starting from the most
  // significant (rightmost) bits.
  int current_byte_index = rightmost_byte_index;
  uint64_t unpacked_element = 0;
  int bits_left_to_unpack = bit_width;

  // If right_boundary_bit_position is in the middle of a byte, unpack the bits
  // up to right_boundary_bit_position within that byte.
  if (right_boundary_bit_position % 8 != 0) {
    int bits_to_copy_from_current_byte = (right_boundary_bit_position % 8);
    int lower_bits_mask = (1 << bits_to_copy_from_current_byte) - 1;
    unpacked_element |= (byte_string[current_byte_index] & lower_bits_mask);

    bits_left_to_unpack -= bits_to_copy_from_current_byte;
    current_byte_index--;
  }

  // Current bit position is now aligned with a byte boundary. Unpack as many
  // whole bytes as possible, shifting previously unpacked (higher-order)
  // bits up to make room.
  while (bits_left_to_unpack >= 8) {
    unpacked_element <<= 8;
    unpacked_element |= byte_string[current_byte_index] & 0xff;

    bits_left_to_unpack -= 8;
    current_byte_index--;
  }

  // Unpack the remaining partial byte, if necessary. Only the high-order
  // bits of this byte belong to this element; the low-order ones belong to
  // the previous element.
  if (bits_left_to_unpack > 0) {
    unpacked_element <<= bits_left_to_unpack;
    int bits_to_skip_in_current_byte = 8 - bits_left_to_unpack;
    unpacked_element |= (byte_string[current_byte_index] & 0xff) >>
                        bits_to_skip_in_current_byte;
  }

  return unpacked_element;
}
209 
PackUint64IntoByteStringBranchless(const absl::Span<const uint64_t> span)210 void SecAggVector::PackUint64IntoByteStringBranchless(
211     const absl::Span<const uint64_t> span) {
212   SecAggVector::Coder coder(modulus_, bit_width_, num_elements_);
213   for (uint64_t element : span) {
214     coder.WriteValue(element);
215   }
216   packed_bytes_ = std::move(coder).Create().TakePackedBytes();
217 }
218 
UnpackByteStringToUint64VectorBranchless(std::vector<uint64_t> * long_vector) const219 void SecAggVector::UnpackByteStringToUint64VectorBranchless(
220     std::vector<uint64_t>* long_vector) const {
221   long_vector->resize(num_elements_);
222   Decoder decoder(*this);
223   for (uint64_t& element : *long_vector) {
224     element = decoder.ReadValue();
225   }
226 }
227 
// Sets up streaming-decode state over `packed_bytes`: positions the read
// cursor at the start of the buffer and primes `scratch_` with the first
// block of bytes via ReadData().
SecAggVector::Decoder::Decoder(absl::string_view packed_bytes, uint64_t modulus)
    : read_cursor_(packed_bytes.data()),
      cursor_sentinel_(packed_bytes.data() + packed_bytes.size()),
      cursor_read_value_(0),
      scratch_(0),
      read_cursor_bit_(0),
      bit_width_(SecAggVector::GetBitWidth(modulus)),
      mask_((1ULL << bit_width_) - 1),  // low bit_width_ bits set
      modulus_(modulus) {
  ReadData();
}
239 
// Refills `cursor_read_value_` with up to eight bytes starting at
// `read_cursor_`, then ORs it into `scratch_` above the `read_cursor_bit_`
// bits already held there. Does not advance the cursor; ReadValue() does.
inline void SecAggVector::Decoder::ReadData() {
  static constexpr ssize_t kBlockSizeBytes = sizeof(cursor_read_value_);
  const ptrdiff_t bytes_remaining = cursor_sentinel_ - read_cursor_;
  // Here, we use memcpy() to avoid the undefined behavior of
  // reinterpret_cast<> on unaligned reads, opportunistically reading up to
  // eight bytes at a time.
  if (bytes_remaining >= kBlockSizeBytes) {
    memcpy(&cursor_read_value_, read_cursor_, kBlockSizeBytes);
  } else {
    // Near the end of the buffer only `bytes_remaining` bytes are copied,
    // so the upper bytes of cursor_read_value_ retain their previous
    // contents. NOTE(review): those stale bits appear never to be consumed,
    // since ReadValue() is invoked at most num_elements_ times — confirm.
    memcpy(&cursor_read_value_, read_cursor_,
           bytes_remaining > 0 ? bytes_remaining : 0);
  }
  scratch_ |= cursor_read_value_ << static_cast<unsigned>(read_cursor_bit_);
}
254 
// Extracts the next bit_width_-wide value from `scratch_`, reduces it into
// [0, modulus_), and advances the streaming state. The conditionals are
// deliberately written as value selections so the compiler can emit
// branchless conditional moves (see inline comments).
uint64_t SecAggVector::Decoder::ReadValue() {
  static constexpr int kBlockSizeBits = sizeof(cursor_read_value_) * 8;
  // Get the current value.
  const uint64_t current_value = scratch_ & mask_;
  // Advance to the next value.
  scratch_ >>= bit_width_;
  int unwritten_bits = read_cursor_bit_;
  read_cursor_bit_ -= bit_width_;
  // Because we read in eight byte chunks on byte boundaries, and only keep
  // eight bytes of scratch, a portion of the read could not fit, and now
  // belongs at the back of scratch. The following assignments are compiled
  // to a branchless conditional move on Clang X86_64 and Clang ARMv{7, 8}.
  int read_bit_shift = bit_width_ - unwritten_bits;
  // At most one of the two shift amounts below is nonzero; together they
  // realign the last read's leftover bits with the vacated top of scratch_.
  unsigned int right_shift_value = read_bit_shift > 0 ? read_bit_shift : 0;
  unsigned int left_shift_value = read_bit_shift < 0 ? -read_bit_shift : 0;
  cursor_read_value_ >>= right_shift_value;
  cursor_read_value_ <<= left_shift_value;
  scratch_ |= cursor_read_value_;
  // Count how many bits of scratch_ now hold valid data, capped at the
  // scratch word's capacity (kBlockSizeBits).
  int valid_scratch_bits = kBlockSizeBits - bit_width_ + unwritten_bits;
  valid_scratch_bits = (valid_scratch_bits > kBlockSizeBits)
                           ? kBlockSizeBits
                           : valid_scratch_bits;
  // Advance the cursor bit up to the valid-bit count, rounded down to a
  // whole byte (`& ~7U`) so that reads stay byte-aligned.
  int new_read_cursor_bit =
      read_cursor_bit_ +
      static_cast<signed>(
          (static_cast<unsigned>(valid_scratch_bits - read_cursor_bit_) & ~7U));
  // A cursor bit equal to kBlockSizeBits would shift the next read out of
  // scratch entirely; cap it one byte lower.
  new_read_cursor_bit = new_read_cursor_bit == kBlockSizeBits
                            ? static_cast<int>(kBlockSizeBits - 8)
                            : new_read_cursor_bit;
  read_cursor_ +=
      static_cast<unsigned>((new_read_cursor_bit - read_cursor_bit_)) / 8;
  read_cursor_bit_ = new_read_cursor_bit;
  ReadData();
  // The current_value is guaranteed to be in [0, 2 * modulus_) range due to the
  // relationship between modulus_ and bit_width_, and therefore the below
  // statement guarantees the return value to be in [0, modulus_) range.
  return current_value < modulus_ ? current_value : current_value - modulus_;
}
293 
Coder(uint64_t modulus,int bit_width,size_t num_elements)294 SecAggVector::Coder::Coder(uint64_t modulus, int bit_width, size_t num_elements)
295     : modulus_(modulus),
296       bit_width_(bit_width),
297       num_elements_(num_elements),
298       target_cursor_value_(0),
299       starting_bit_position_(0) {
300   num_bytes_needed_ =
301       DivideRoundUp(static_cast<uint32_t>(num_elements_ * bit_width_), 8);
302   // The branchless variant assumes eight bytes of scratch space.
303   // The string is resized to the correct size at the end.
304   packed_bytes_ = std::string(num_bytes_needed_ + 8, '\0');
305   write_cursor_ = &packed_bytes_[0];
306 }
307 
// Appends one bit_width_-wide value to the packed buffer. Maintains
// target_cursor_value_ as a scratch word whose low starting_bit_position_
// bits are already-committed data; like Decoder::ReadValue, the
// conditionals are value selections that compile to branchless moves.
void SecAggVector::Coder::WriteValue(uint64_t value) {
  static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
  // Here, we use memcpy() to avoid the undefined behavior of
  // reinterpret_cast<> on unaligned stores, opportunistically writing eight
  // bytes at a time.
  // NOTE(review): the byte-level layout produced by these raw word stores
  // depends on host byte order — presumably little-endian is assumed;
  // confirm it matches the decoder and per-element packer.
  // Keep only the already-committed low bits, then splice the new value in
  // above them and flush the whole scratch word.
  target_cursor_value_ &= (1ULL << starting_bit_position_) - 1;
  target_cursor_value_ |= value << starting_bit_position_;
  std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
  // Advance by the number of bytes now fully written; the remainder becomes
  // the new intra-byte bit offset.
  const auto new_write_cursor =
      write_cursor_ + (starting_bit_position_ + bit_width_) / 8;
  const auto new_starting_bit_position =
      (starting_bit_position_ + bit_width_) % 8;
  // Because we write in eight byte chunks, a portion of element may have
  // been missed, and now belongs at the front of target_cursor_value. The
  // following assignments are compiled to a branchless conditional move on
  // Clang X86_64 and Clang ARMv{7, 8}.
  auto runt_cursor_value =
      new_starting_bit_position
          ? value >> (static_cast<unsigned>(kBlockSize * 8 -
                                            starting_bit_position_) &
                      (kBlockSize * 8 - 1))  // Prevent unused UB warning.
          : 0;
  // Otherwise, remove fully written values from our scratch space.
  target_cursor_value_ >>=
      (static_cast<unsigned>(new_write_cursor - write_cursor_) * 8) &
      (kBlockSize * 8 - 1);  // Prevent unused UB warning.
  // If the entire scratch word was flushed, restart it from the bits of
  // `value` that did not fit; otherwise keep the shifted remainder.
  target_cursor_value_ = (new_write_cursor - write_cursor_ == kBlockSize)
                             ? runt_cursor_value
                             : target_cursor_value_;
  write_cursor_ = new_write_cursor;
  starting_bit_position_ = new_starting_bit_position;
}
340 
Create()341 SecAggVector SecAggVector::Coder::Create() && {
342   static constexpr size_t kBlockSize = sizeof(target_cursor_value_);
343   std::memcpy(write_cursor_, &target_cursor_value_, kBlockSize);
344   packed_bytes_.resize(num_bytes_needed_);
345   return SecAggVector(std::move(packed_bytes_), modulus_, num_elements_,
346                       /* branchless_codec=*/true);
347 }
348 
Add(const SecAggVector & other)349 void SecAggUnpackedVector::Add(const SecAggVector& other) {
350   FCP_CHECK(num_elements() == other.num_elements());
351   FCP_CHECK(modulus() == other.modulus());
352   SecAggVector::Decoder decoder(other);
353   for (auto& v : *this) {
354     v = AddModOpt(v, decoder.ReadValue(), modulus());
355   }
356 }
357 
Add(const SecAggVectorMap & other)358 void SecAggUnpackedVectorMap::Add(const SecAggVectorMap& other) {
359   FCP_CHECK(size() == other.size());
360   for (auto& [name, vector] : *this) {
361     auto it = other.find(name);
362     FCP_CHECK(it != other.end());
363     vector.Add(it->second);
364   }
365 }
366 
AddMaps(const SecAggUnpackedVectorMap & a,const SecAggUnpackedVectorMap & b)367 std::unique_ptr<SecAggUnpackedVectorMap> SecAggUnpackedVectorMap::AddMaps(
368     const SecAggUnpackedVectorMap& a, const SecAggUnpackedVectorMap& b) {
369   auto result = std::make_unique<SecAggUnpackedVectorMap>();
370   for (const auto& entry : a) {
371     auto name = entry.first;
372     auto length = entry.second.num_elements();
373     auto modulus = entry.second.modulus();
374     const auto& a_at_name = entry.second;
375     const auto& b_at_name = b.at(name);
376     SecAggUnpackedVector result_vector(length, modulus);
377     for (int j = 0; j < length; ++j) {
378       result_vector[j] = AddModOpt(a_at_name[j], b_at_name[j], modulus);
379     }
380     result->emplace(name, std::move(result_vector));
381   }
382   return result;
383 }
384 
385 }  // namespace secagg
386 }  // namespace fcp
387