xref: /aosp_15_r20/external/pigweed/pw_protobuf/encoder.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/encoder.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21 
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35 
36 namespace pw::protobuf {
37 
38 using internal::VarintType;
39 
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41                                               bool write_when_empty) {
42   PW_CHECK(!nested_encoder_open());
43 
44   nested_field_number_ = field_number;
45   if (!ValidFieldNumber(field_number)) {
46     status_.Update(Status::InvalidArgument());
47     return StreamEncoder(*this, ByteSpan(), false);
48   }
49 
50   // Pass the unused space of the scratch buffer to the nested encoder to use
51   // as their scratch buffer.
52   size_t key_size =
53       varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54   size_t reserved_size = key_size + config::kMaxVarintSize;
55   size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56                              writer_.ConservativeWriteLimit());
57   // Account for reserved bytes.
58   max_size = max_size > reserved_size ? max_size - reserved_size : 0;
59   // Cap based on max varint size.
60   max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
61                       static_cast<uint64_t>(max_size));
62 
63   ByteSpan nested_buffer;
64   if (max_size > 0) {
65     nested_buffer = ByteSpan(
66         memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
67         max_size);
68   } else {
69     nested_buffer = ByteSpan();
70   }
71   return StreamEncoder(*this, nested_buffer, write_when_empty);
72 }
73 
CloseEncoder()74 void StreamEncoder::CloseEncoder() {
75   // If this was an invalidated StreamEncoder which cannot be used, permit the
76   // object to be cleanly destructed by doing nothing.
77   if (nested_field_number_ == kFirstReservedNumber) {
78     return;
79   }
80 
81   PW_CHECK(
82       !nested_encoder_open(),
83       "Tried to destruct a proto encoder with an active submessage encoder");
84 
85   if (parent_ != nullptr) {
86     parent_->CloseNestedMessage(*this);
87   }
88 }
89 
CloseNestedMessage(StreamEncoder & nested)90 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
91   PW_DCHECK_PTR_EQ(nested.parent_,
92                    this,
93                    "CloseNestedMessage() called on the wrong Encoder parent");
94 
95   // Make the nested encoder look like it has an open child to block writes for
96   // the remainder of the object's life.
97   nested.nested_field_number_ = kFirstReservedNumber;
98   nested.parent_ = nullptr;
99   // Temporarily cache the field number of the child so we can re-enable
100   // writing to this encoder.
101   uint32_t temp_field_number = nested_field_number_;
102   nested_field_number_ = 0;
103 
104   // TODO(amontanez): If a submessage fails, we could optionally discard
105   // it and continue happily. For now, we'll always invalidate the entire
106   // encoder if a single submessage fails.
107   status_.Update(nested.status_);
108   if (!status_.ok()) {
109     return;
110   }
111 
112   if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
113       config::kMaxVarintSize) {
114     status_ = Status::OutOfRange();
115     return;
116   }
117 
118   if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
119     return;
120   }
121 
122   status_ = WriteLengthDelimitedField(temp_field_number,
123                                       nested.memory_writer_.WrittenData());
124 }
125 
WriteVarintField(uint32_t field_number,uint64_t value)126 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
127   PW_TRY(UpdateStatusForWrite(
128       field_number, WireType::kVarint, varint::EncodedSize(value)));
129 
130   WriteVarint(FieldKey(field_number, WireType::kVarint))
131       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
132   return WriteVarint(value);
133 }
134 
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)135 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
136                                                 ConstByteSpan data) {
137   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
138   status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
139       field_number, data.size(), writer_));
140   PW_TRY(status_);
141   if (Status status = writer_.Write(data); !status.ok()) {
142     status_ = status;
143   }
144   return status_;
145 }
146 
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)147 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
148     uint32_t field_number,
149     stream::Reader& bytes_reader,
150     size_t num_bytes,
151     ByteSpan stream_pipe_buffer) {
152   PW_CHECK_UINT_GT(
153       stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
154   PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
155   status_.Update(
156       WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
157   PW_TRY(status_);
158 
159   // Stream data from `bytes_reader` to `writer_`.
160   // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
161   // time.
162   for (size_t bytes_written = 0; bytes_written < num_bytes;) {
163     const size_t chunk_size_bytes =
164         std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
165     const Result<ByteSpan> read_result =
166         bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
167     status_.Update(read_result.status());
168     PW_TRY(status_);
169 
170     status_.Update(writer_.Write(read_result.value()));
171     PW_TRY(status_);
172 
173     bytes_written += read_result.value().size();
174   }
175 
176   return OkStatus();
177 }
178 
WriteFixed(uint32_t field_number,ConstByteSpan data)179 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
180   WireType type =
181       data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
182 
183   PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
184 
185   WriteVarint(FieldKey(field_number, type))
186       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
187   if (Status status = writer_.Write(data); !status.ok()) {
188     status_ = status;
189   }
190   return status_;
191 }
192 
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)193 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
194                                        span<const std::byte> values,
195                                        size_t elem_size) {
196   if (values.empty()) {
197     return status_;
198   }
199 
200   PW_CHECK_NOTNULL(values.data());
201   PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
202 
203   PW_TRY(UpdateStatusForWrite(
204       field_number, WireType::kDelimited, values.size_bytes()));
205   WriteVarint(FieldKey(field_number, WireType::kDelimited))
206       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
207   WriteVarint(values.size_bytes())
208       .IgnoreError();  // TODO: b/242598609 - Handle Status properly
209 
210   for (auto val_start = values.begin(); val_start != values.end();
211        val_start += elem_size) {
212     // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
213     // little-endian for serialization.
214     std::array<std::byte, sizeof(uint64_t)> data;
215     if (endian::native == endian::little) {
216       std::copy(val_start, val_start + elem_size, std::begin(data));
217     } else {
218       std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
219     }
220     status_.Update(writer_.Write(span(data).first(elem_size)));
221     PW_TRY(status_);
222   }
223   return status_;
224 }
225 
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)226 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
227                                            WireType type,
228                                            size_t data_size) {
229   PW_CHECK(!nested_encoder_open());
230   PW_TRY(status_);
231 
232   if (!ValidFieldNumber(field_number)) {
233     return status_ = Status::InvalidArgument();
234   }
235 
236   const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
237   status_.Update(field_size.status());
238   PW_TRY(status_);
239 
240   if (field_size.value() > writer_.ConservativeWriteLimit()) {
241     status_ = Status::ResourceExhausted();
242   }
243 
244   return status_;
245 }
246 
Write(span<const std::byte> message,span<const internal::MessageField> table)247 Status StreamEncoder::Write(span<const std::byte> message,
248                             span<const internal::MessageField> table) {
249   PW_CHECK(!nested_encoder_open());
250   PW_TRY(status_);
251 
252   for (const auto& field : table) {
253     // Calculate the span of bytes corresponding to the structure field to
254     // read from.
255     ConstByteSpan values =
256         message.subspan(field.field_offset(), field.field_size());
257     PW_CHECK(values.begin() >= message.begin() &&
258              values.end() <= message.end());
259 
260     // If the field is using callbacks, interpret the input field accordingly
261     // and allow the caller to provide custom handling.
262     if (field.callback_type() == internal::CallbackType::kSingleField) {
263       const Callback<StreamEncoder, StreamDecoder>* callback =
264           reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
265               values.data());
266       PW_TRY(callback->Encode(*this));
267       continue;
268     } else if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
269       const OneOf<StreamEncoder, StreamDecoder>* callback =
270           reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
271               values.data());
272       PW_TRY(callback->Encode(*this));
273       continue;
274     }
275 
276     switch (field.wire_type()) {
277       case WireType::kFixed64:
278       case WireType::kFixed32: {
279         // Fixed fields call WriteFixed() for singular case and
280         // WritePackedFixed() for repeated fields.
281         PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
282                                            ? sizeof(uint32_t)
283                                            : sizeof(uint64_t)),
284                  "Mismatched message field type and size");
285         if (field.is_fixed_size()) {
286           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
287           if (static_cast<size_t>(
288                   std::count(values.begin(), values.end(), std::byte{0})) <
289               values.size()) {
290             PW_TRY(WritePackedFixed(
291                 field.field_number(), values, field.elem_size()));
292           }
293         } else if (field.is_repeated()) {
294           // The struct member for this field is a vector of a type
295           // corresponding to the field element size. Cast to the correct
296           // vector type so we're not performing type aliasing (except for
297           // unsigned vs signed which is explicitly allowed).
298           if (field.elem_size() == sizeof(uint64_t)) {
299             const auto* vector =
300                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
301                     values.data());
302             if (!vector->empty()) {
303               PW_TRY(WritePackedFixed(
304                   field.field_number(),
305                   as_bytes(span(vector->data(), vector->size())),
306                   field.elem_size()));
307             }
308           } else if (field.elem_size() == sizeof(uint32_t)) {
309             const auto* vector =
310                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
311                     values.data());
312             if (!vector->empty()) {
313               PW_TRY(WritePackedFixed(
314                   field.field_number(),
315                   as_bytes(span(vector->data(), vector->size())),
316                   field.elem_size()));
317             }
318           }
319         } else if (field.is_optional()) {
320           // The struct member for this field is a std::optional of a type
321           // corresponding to the field element size. Cast to the correct
322           // optional type so we're not performing type aliasing (except for
323           // unsigned vs signed which is explicitly allowed), and write from
324           // a temporary.
325           if (field.elem_size() == sizeof(uint64_t)) {
326             const auto* optional =
327                 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
328             if (optional->has_value()) {
329               uint64_t value = optional->value();
330               PW_TRY(
331                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
332             }
333           } else if (field.elem_size() == sizeof(uint32_t)) {
334             const auto* optional =
335                 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
336             if (optional->has_value()) {
337               uint32_t value = optional->value();
338               PW_TRY(
339                   WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
340             }
341           }
342         } else {
343           PW_CHECK(values.size() == field.elem_size(),
344                    "Mismatched message field type and size");
345           if (static_cast<size_t>(
346                   std::count(values.begin(), values.end(), std::byte{0})) <
347               values.size()) {
348             PW_TRY(WriteFixed(field.field_number(), values));
349           }
350         }
351         break;
352       }
353       case WireType::kVarint: {
354         // Varint fields call WriteVarintField() for singular case and
355         // WritePackedVarints() for repeated fields.
356         PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
357                      field.elem_size() == sizeof(uint32_t) ||
358                      field.elem_size() == sizeof(bool),
359                  "Mismatched message field type and size");
360         if (field.is_fixed_size()) {
361           // The struct member for this field is an array of type corresponding
362           // to the field element size. Cast to a span of the correct type over
363           // the array so we're not performing type aliasing (except for
364           // unsigned vs signed which is explicitly allowed).
365           PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
366           if (static_cast<size_t>(
367                   std::count(values.begin(), values.end(), std::byte{0})) ==
368               values.size()) {
369             continue;
370           }
371           if (field.elem_size() == sizeof(uint64_t)) {
372             PW_TRY(WritePackedVarints(
373                 field.field_number(),
374                 span(reinterpret_cast<const uint64_t*>(values.data()),
375                      values.size() / field.elem_size()),
376                 field.varint_type()));
377           } else if (field.elem_size() == sizeof(uint32_t)) {
378             PW_TRY(WritePackedVarints(
379                 field.field_number(),
380                 span(reinterpret_cast<const uint32_t*>(values.data()),
381                      values.size() / field.elem_size()),
382                 field.varint_type()));
383           } else if (field.elem_size() == sizeof(bool)) {
384             static_assert(sizeof(bool) == sizeof(uint8_t),
385                           "bool must be same size as uint8_t");
386             PW_TRY(WritePackedVarints(
387                 field.field_number(),
388                 span(reinterpret_cast<const uint8_t*>(values.data()),
389                      values.size() / field.elem_size()),
390                 field.varint_type()));
391           }
392         } else if (field.is_repeated()) {
393           // The struct member for this field is a vector of a type
394           // corresponding to the field element size. Cast to the correct
395           // vector type so we're not performing type aliasing (except for
396           // unsigned vs signed which is explicitly allowed).
397           if (field.elem_size() == sizeof(uint64_t)) {
398             const auto* vector =
399                 reinterpret_cast<const pw::Vector<const uint64_t>*>(
400                     values.data());
401             if (!vector->empty()) {
402               PW_TRY(WritePackedVarints(field.field_number(),
403                                         span(vector->data(), vector->size()),
404                                         field.varint_type()));
405             }
406           } else if (field.elem_size() == sizeof(uint32_t)) {
407             const auto* vector =
408                 reinterpret_cast<const pw::Vector<const uint32_t>*>(
409                     values.data());
410             if (!vector->empty()) {
411               PW_TRY(WritePackedVarints(field.field_number(),
412                                         span(vector->data(), vector->size()),
413                                         field.varint_type()));
414             }
415           } else if (field.elem_size() == sizeof(bool)) {
416             static_assert(sizeof(bool) == sizeof(uint8_t),
417                           "bool must be same size as uint8_t");
418             const auto* vector =
419                 reinterpret_cast<const pw::Vector<const uint8_t>*>(
420                     values.data());
421             if (!vector->empty()) {
422               PW_TRY(WritePackedVarints(field.field_number(),
423                                         span(vector->data(), vector->size()),
424                                         field.varint_type()));
425             }
426           }
427         } else if (field.is_optional()) {
428           // The struct member for this field is a std::optional of a type
429           // corresponding to the field element size. Cast to the correct
430           // optional type so we're not performing type aliasing (except for
431           // unsigned vs signed which is explicitly allowed), and write from
432           // a temporary.
433           uint64_t value = 0;
434           if (field.elem_size() == sizeof(uint64_t)) {
435             if (field.varint_type() == VarintType::kUnsigned) {
436               const auto* optional =
437                   reinterpret_cast<const std::optional<uint64_t>*>(
438                       values.data());
439               if (!optional->has_value()) {
440                 continue;
441               }
442               value = optional->value();
443             } else {
444               const auto* optional =
445                   reinterpret_cast<const std::optional<int64_t>*>(
446                       values.data());
447               if (!optional->has_value()) {
448                 continue;
449               }
450               value = field.varint_type() == VarintType::kZigZag
451                           ? varint::ZigZagEncode(optional->value())
452                           : optional->value();
453             }
454           } else if (field.elem_size() == sizeof(uint32_t)) {
455             if (field.varint_type() == VarintType::kUnsigned) {
456               const auto* optional =
457                   reinterpret_cast<const std::optional<uint32_t>*>(
458                       values.data());
459               if (!optional->has_value()) {
460                 continue;
461               }
462               value = optional->value();
463             } else {
464               const auto* optional =
465                   reinterpret_cast<const std::optional<int32_t>*>(
466                       values.data());
467               if (!optional->has_value()) {
468                 continue;
469               }
470               value = field.varint_type() == VarintType::kZigZag
471                           ? varint::ZigZagEncode(optional->value())
472                           : optional->value();
473             }
474           } else if (field.elem_size() == sizeof(bool)) {
475             const auto* optional =
476                 reinterpret_cast<const std::optional<bool>*>(values.data());
477             if (!optional->has_value()) {
478               continue;
479             }
480             value = optional->value();
481           }
482           PW_TRY(WriteVarintField(field.field_number(), value));
483         } else {
484           // The struct member for this field is a scalar of a type
485           // corresponding to the field element size. Cast to the correct
486           // type to retrieve the value before passing to WriteVarintField()
487           // so we're not performing type aliasing (except for unsigned vs
488           // signed which is explicitly allowed).
489           PW_CHECK(values.size() == field.elem_size(),
490                    "Mismatched message field type and size");
491           uint64_t value = 0;
492           if (field.elem_size() == sizeof(uint64_t)) {
493             if (field.varint_type() == VarintType::kZigZag) {
494               value = varint::ZigZagEncode(
495                   *reinterpret_cast<const int64_t*>(values.data()));
496             } else if (field.varint_type() == VarintType::kNormal) {
497               value = *reinterpret_cast<const int64_t*>(values.data());
498             } else {
499               value = *reinterpret_cast<const uint64_t*>(values.data());
500             }
501             if (!value) {
502               continue;
503             }
504           } else if (field.elem_size() == sizeof(uint32_t)) {
505             if (field.varint_type() == VarintType::kZigZag) {
506               value = varint::ZigZagEncode(
507                   *reinterpret_cast<const int32_t*>(values.data()));
508             } else if (field.varint_type() == VarintType::kNormal) {
509               value = *reinterpret_cast<const int32_t*>(values.data());
510             } else {
511               value = *reinterpret_cast<const uint32_t*>(values.data());
512             }
513             if (!value) {
514               continue;
515             }
516           } else if (field.elem_size() == sizeof(bool)) {
517             value = *reinterpret_cast<const bool*>(values.data());
518             if (!value) {
519               continue;
520             }
521           }
522           PW_TRY(WriteVarintField(field.field_number(), value));
523         }
524         break;
525       }
526       case WireType::kDelimited: {
527         // Delimited fields are always a singular case because of the
528         // inability to cast to a generic vector with an element of a certain
529         // size (we always need a type).
530         PW_CHECK(!field.is_repeated(),
531                  "Repeated delimited messages always require a callback");
532         if (field.nested_message_fields()) {
533           // Nested Message. Struct member is an embedded struct for the
534           // nested field. Obtain a nested encoder and recursively call Write()
535           // using the fields table pointer from this field.
536           auto nested_encoder = GetNestedEncoder(field.field_number(),
537                                                  /*write_when_empty=*/false);
538           PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
539         } else if (field.is_fixed_size()) {
540           // Fixed-length bytes field. Struct member is a std::array<std::byte>.
541           // Call WriteLengthDelimitedField() to output it to the stream.
542           PW_CHECK(field.elem_size() == sizeof(std::byte),
543                    "Mismatched message field type and size");
544           if (static_cast<size_t>(
545                   std::count(values.begin(), values.end(), std::byte{0})) <
546               values.size()) {
547             PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
548           }
549         } else {
550           // bytes or string field with a maximum size. Struct member is
551           // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
552           // Use the contents as a span and call WriteLengthDelimitedField() to
553           // output it to the stream.
554           PW_CHECK(field.elem_size() == sizeof(std::byte),
555                    "Mismatched message field type and size");
556           if (field.is_string()) {
557             PW_TRY(WriteStringOrBytes<const InlineString<>>(
558                 field.field_number(), values.data()));
559           } else {
560             PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
561                 field.field_number(), values.data()));
562           }
563         }
564         break;
565       }
566     }
567   }
568 
569   ResetOneOfCallbacks(message, table);
570 
571   return status_;
572 }
573 
ResetOneOfCallbacks(ConstByteSpan message,span<const internal::MessageField> table)574 void StreamEncoder::ResetOneOfCallbacks(
575     ConstByteSpan message, span<const internal::MessageField> table) {
576   for (const auto& field : table) {
577     // Calculate the span of bytes corresponding to the structure field to
578     // read from.
579     ConstByteSpan values =
580         message.subspan(field.field_offset(), field.field_size());
581     PW_CHECK(values.begin() >= message.begin() &&
582              values.end() <= message.end());
583 
584     if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
585       const OneOf<StreamEncoder, StreamDecoder>* callback =
586           reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
587               values.data());
588       callback->invoked_ = false;
589     }
590   }
591 }
592 
593 }  // namespace pw::protobuf
594