1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_protobuf/encoder.h"
16
17 #include <algorithm>
18 #include <cstddef>
19 #include <cstring>
20 #include <optional>
21
22 #include "pw_assert/check.h"
23 #include "pw_bytes/span.h"
24 #include "pw_protobuf/internal/codegen.h"
25 #include "pw_protobuf/serialized_size.h"
26 #include "pw_protobuf/stream_decoder.h"
27 #include "pw_protobuf/wire_format.h"
28 #include "pw_span/span.h"
29 #include "pw_status/status.h"
30 #include "pw_status/try.h"
31 #include "pw_stream/memory_stream.h"
32 #include "pw_stream/stream.h"
33 #include "pw_string/string.h"
34 #include "pw_varint/varint.h"
35
36 namespace pw::protobuf {
37
38 using internal::VarintType;
39
GetNestedEncoder(uint32_t field_number,bool write_when_empty)40 StreamEncoder StreamEncoder::GetNestedEncoder(uint32_t field_number,
41 bool write_when_empty) {
42 PW_CHECK(!nested_encoder_open());
43
44 nested_field_number_ = field_number;
45 if (!ValidFieldNumber(field_number)) {
46 status_.Update(Status::InvalidArgument());
47 return StreamEncoder(*this, ByteSpan(), false);
48 }
49
50 // Pass the unused space of the scratch buffer to the nested encoder to use
51 // as their scratch buffer.
52 size_t key_size =
53 varint::EncodedSize(FieldKey(field_number, WireType::kDelimited));
54 size_t reserved_size = key_size + config::kMaxVarintSize;
55 size_t max_size = std::min(memory_writer_.ConservativeWriteLimit(),
56 writer_.ConservativeWriteLimit());
57 // Account for reserved bytes.
58 max_size = max_size > reserved_size ? max_size - reserved_size : 0;
59 // Cap based on max varint size.
60 max_size = std::min(varint::MaxValueInBytes(config::kMaxVarintSize),
61 static_cast<uint64_t>(max_size));
62
63 ByteSpan nested_buffer;
64 if (max_size > 0) {
65 nested_buffer = ByteSpan(
66 memory_writer_.data() + reserved_size + memory_writer_.bytes_written(),
67 max_size);
68 } else {
69 nested_buffer = ByteSpan();
70 }
71 return StreamEncoder(*this, nested_buffer, write_when_empty);
72 }
73
CloseEncoder()74 void StreamEncoder::CloseEncoder() {
75 // If this was an invalidated StreamEncoder which cannot be used, permit the
76 // object to be cleanly destructed by doing nothing.
77 if (nested_field_number_ == kFirstReservedNumber) {
78 return;
79 }
80
81 PW_CHECK(
82 !nested_encoder_open(),
83 "Tried to destruct a proto encoder with an active submessage encoder");
84
85 if (parent_ != nullptr) {
86 parent_->CloseNestedMessage(*this);
87 }
88 }
89
CloseNestedMessage(StreamEncoder & nested)90 void StreamEncoder::CloseNestedMessage(StreamEncoder& nested) {
91 PW_DCHECK_PTR_EQ(nested.parent_,
92 this,
93 "CloseNestedMessage() called on the wrong Encoder parent");
94
95 // Make the nested encoder look like it has an open child to block writes for
96 // the remainder of the object's life.
97 nested.nested_field_number_ = kFirstReservedNumber;
98 nested.parent_ = nullptr;
99 // Temporarily cache the field number of the child so we can re-enable
100 // writing to this encoder.
101 uint32_t temp_field_number = nested_field_number_;
102 nested_field_number_ = 0;
103
104 // TODO(amontanez): If a submessage fails, we could optionally discard
105 // it and continue happily. For now, we'll always invalidate the entire
106 // encoder if a single submessage fails.
107 status_.Update(nested.status_);
108 if (!status_.ok()) {
109 return;
110 }
111
112 if (varint::EncodedSize(nested.memory_writer_.bytes_written()) >
113 config::kMaxVarintSize) {
114 status_ = Status::OutOfRange();
115 return;
116 }
117
118 if (!nested.memory_writer_.bytes_written() && !nested.write_when_empty_) {
119 return;
120 }
121
122 status_ = WriteLengthDelimitedField(temp_field_number,
123 nested.memory_writer_.WrittenData());
124 }
125
WriteVarintField(uint32_t field_number,uint64_t value)126 Status StreamEncoder::WriteVarintField(uint32_t field_number, uint64_t value) {
127 PW_TRY(UpdateStatusForWrite(
128 field_number, WireType::kVarint, varint::EncodedSize(value)));
129
130 WriteVarint(FieldKey(field_number, WireType::kVarint))
131 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
132 return WriteVarint(value);
133 }
134
WriteLengthDelimitedField(uint32_t field_number,ConstByteSpan data)135 Status StreamEncoder::WriteLengthDelimitedField(uint32_t field_number,
136 ConstByteSpan data) {
137 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, data.size()));
138 status_.Update(WriteLengthDelimitedKeyAndLengthPrefix(
139 field_number, data.size(), writer_));
140 PW_TRY(status_);
141 if (Status status = writer_.Write(data); !status.ok()) {
142 status_ = status;
143 }
144 return status_;
145 }
146
WriteLengthDelimitedFieldFromStream(uint32_t field_number,stream::Reader & bytes_reader,size_t num_bytes,ByteSpan stream_pipe_buffer)147 Status StreamEncoder::WriteLengthDelimitedFieldFromStream(
148 uint32_t field_number,
149 stream::Reader& bytes_reader,
150 size_t num_bytes,
151 ByteSpan stream_pipe_buffer) {
152 PW_CHECK_UINT_GT(
153 stream_pipe_buffer.size(), 0, "Transfer buffer cannot be 0 size");
154 PW_TRY(UpdateStatusForWrite(field_number, WireType::kDelimited, num_bytes));
155 status_.Update(
156 WriteLengthDelimitedKeyAndLengthPrefix(field_number, num_bytes, writer_));
157 PW_TRY(status_);
158
159 // Stream data from `bytes_reader` to `writer_`.
160 // TODO(pwbug/468): move the following logic to pw_stream/copy.h at a later
161 // time.
162 for (size_t bytes_written = 0; bytes_written < num_bytes;) {
163 const size_t chunk_size_bytes =
164 std::min(num_bytes - bytes_written, stream_pipe_buffer.size_bytes());
165 const Result<ByteSpan> read_result =
166 bytes_reader.Read(stream_pipe_buffer.data(), chunk_size_bytes);
167 status_.Update(read_result.status());
168 PW_TRY(status_);
169
170 status_.Update(writer_.Write(read_result.value()));
171 PW_TRY(status_);
172
173 bytes_written += read_result.value().size();
174 }
175
176 return OkStatus();
177 }
178
WriteFixed(uint32_t field_number,ConstByteSpan data)179 Status StreamEncoder::WriteFixed(uint32_t field_number, ConstByteSpan data) {
180 WireType type =
181 data.size() == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
182
183 PW_TRY(UpdateStatusForWrite(field_number, type, data.size()));
184
185 WriteVarint(FieldKey(field_number, type))
186 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
187 if (Status status = writer_.Write(data); !status.ok()) {
188 status_ = status;
189 }
190 return status_;
191 }
192
WritePackedFixed(uint32_t field_number,span<const std::byte> values,size_t elem_size)193 Status StreamEncoder::WritePackedFixed(uint32_t field_number,
194 span<const std::byte> values,
195 size_t elem_size) {
196 if (values.empty()) {
197 return status_;
198 }
199
200 PW_CHECK_NOTNULL(values.data());
201 PW_DCHECK(elem_size == sizeof(uint32_t) || elem_size == sizeof(uint64_t));
202
203 PW_TRY(UpdateStatusForWrite(
204 field_number, WireType::kDelimited, values.size_bytes()));
205 WriteVarint(FieldKey(field_number, WireType::kDelimited))
206 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
207 WriteVarint(values.size_bytes())
208 .IgnoreError(); // TODO: b/242598609 - Handle Status properly
209
210 for (auto val_start = values.begin(); val_start != values.end();
211 val_start += elem_size) {
212 // Allocates 8 bytes so both 4-byte and 8-byte types can be encoded as
213 // little-endian for serialization.
214 std::array<std::byte, sizeof(uint64_t)> data;
215 if (endian::native == endian::little) {
216 std::copy(val_start, val_start + elem_size, std::begin(data));
217 } else {
218 std::reverse_copy(val_start, val_start + elem_size, std::begin(data));
219 }
220 status_.Update(writer_.Write(span(data).first(elem_size)));
221 PW_TRY(status_);
222 }
223 return status_;
224 }
225
UpdateStatusForWrite(uint32_t field_number,WireType type,size_t data_size)226 Status StreamEncoder::UpdateStatusForWrite(uint32_t field_number,
227 WireType type,
228 size_t data_size) {
229 PW_CHECK(!nested_encoder_open());
230 PW_TRY(status_);
231
232 if (!ValidFieldNumber(field_number)) {
233 return status_ = Status::InvalidArgument();
234 }
235
236 const Result<size_t> field_size = SizeOfField(field_number, type, data_size);
237 status_.Update(field_size.status());
238 PW_TRY(status_);
239
240 if (field_size.value() > writer_.ConservativeWriteLimit()) {
241 status_ = Status::ResourceExhausted();
242 }
243
244 return status_;
245 }
246
Write(span<const std::byte> message,span<const internal::MessageField> table)247 Status StreamEncoder::Write(span<const std::byte> message,
248 span<const internal::MessageField> table) {
249 PW_CHECK(!nested_encoder_open());
250 PW_TRY(status_);
251
252 for (const auto& field : table) {
253 // Calculate the span of bytes corresponding to the structure field to
254 // read from.
255 ConstByteSpan values =
256 message.subspan(field.field_offset(), field.field_size());
257 PW_CHECK(values.begin() >= message.begin() &&
258 values.end() <= message.end());
259
260 // If the field is using callbacks, interpret the input field accordingly
261 // and allow the caller to provide custom handling.
262 if (field.callback_type() == internal::CallbackType::kSingleField) {
263 const Callback<StreamEncoder, StreamDecoder>* callback =
264 reinterpret_cast<const Callback<StreamEncoder, StreamDecoder>*>(
265 values.data());
266 PW_TRY(callback->Encode(*this));
267 continue;
268 } else if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
269 const OneOf<StreamEncoder, StreamDecoder>* callback =
270 reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
271 values.data());
272 PW_TRY(callback->Encode(*this));
273 continue;
274 }
275
276 switch (field.wire_type()) {
277 case WireType::kFixed64:
278 case WireType::kFixed32: {
279 // Fixed fields call WriteFixed() for singular case and
280 // WritePackedFixed() for repeated fields.
281 PW_CHECK(field.elem_size() == (field.wire_type() == WireType::kFixed32
282 ? sizeof(uint32_t)
283 : sizeof(uint64_t)),
284 "Mismatched message field type and size");
285 if (field.is_fixed_size()) {
286 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
287 if (static_cast<size_t>(
288 std::count(values.begin(), values.end(), std::byte{0})) <
289 values.size()) {
290 PW_TRY(WritePackedFixed(
291 field.field_number(), values, field.elem_size()));
292 }
293 } else if (field.is_repeated()) {
294 // The struct member for this field is a vector of a type
295 // corresponding to the field element size. Cast to the correct
296 // vector type so we're not performing type aliasing (except for
297 // unsigned vs signed which is explicitly allowed).
298 if (field.elem_size() == sizeof(uint64_t)) {
299 const auto* vector =
300 reinterpret_cast<const pw::Vector<const uint64_t>*>(
301 values.data());
302 if (!vector->empty()) {
303 PW_TRY(WritePackedFixed(
304 field.field_number(),
305 as_bytes(span(vector->data(), vector->size())),
306 field.elem_size()));
307 }
308 } else if (field.elem_size() == sizeof(uint32_t)) {
309 const auto* vector =
310 reinterpret_cast<const pw::Vector<const uint32_t>*>(
311 values.data());
312 if (!vector->empty()) {
313 PW_TRY(WritePackedFixed(
314 field.field_number(),
315 as_bytes(span(vector->data(), vector->size())),
316 field.elem_size()));
317 }
318 }
319 } else if (field.is_optional()) {
320 // The struct member for this field is a std::optional of a type
321 // corresponding to the field element size. Cast to the correct
322 // optional type so we're not performing type aliasing (except for
323 // unsigned vs signed which is explicitly allowed), and write from
324 // a temporary.
325 if (field.elem_size() == sizeof(uint64_t)) {
326 const auto* optional =
327 reinterpret_cast<const std::optional<uint64_t>*>(values.data());
328 if (optional->has_value()) {
329 uint64_t value = optional->value();
330 PW_TRY(
331 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
332 }
333 } else if (field.elem_size() == sizeof(uint32_t)) {
334 const auto* optional =
335 reinterpret_cast<const std::optional<uint32_t>*>(values.data());
336 if (optional->has_value()) {
337 uint32_t value = optional->value();
338 PW_TRY(
339 WriteFixed(field.field_number(), as_bytes(span(&value, 1))));
340 }
341 }
342 } else {
343 PW_CHECK(values.size() == field.elem_size(),
344 "Mismatched message field type and size");
345 if (static_cast<size_t>(
346 std::count(values.begin(), values.end(), std::byte{0})) <
347 values.size()) {
348 PW_TRY(WriteFixed(field.field_number(), values));
349 }
350 }
351 break;
352 }
353 case WireType::kVarint: {
354 // Varint fields call WriteVarintField() for singular case and
355 // WritePackedVarints() for repeated fields.
356 PW_CHECK(field.elem_size() == sizeof(uint64_t) ||
357 field.elem_size() == sizeof(uint32_t) ||
358 field.elem_size() == sizeof(bool),
359 "Mismatched message field type and size");
360 if (field.is_fixed_size()) {
361 // The struct member for this field is an array of type corresponding
362 // to the field element size. Cast to a span of the correct type over
363 // the array so we're not performing type aliasing (except for
364 // unsigned vs signed which is explicitly allowed).
365 PW_CHECK(field.is_repeated(), "Non-repeated fixed size field");
366 if (static_cast<size_t>(
367 std::count(values.begin(), values.end(), std::byte{0})) ==
368 values.size()) {
369 continue;
370 }
371 if (field.elem_size() == sizeof(uint64_t)) {
372 PW_TRY(WritePackedVarints(
373 field.field_number(),
374 span(reinterpret_cast<const uint64_t*>(values.data()),
375 values.size() / field.elem_size()),
376 field.varint_type()));
377 } else if (field.elem_size() == sizeof(uint32_t)) {
378 PW_TRY(WritePackedVarints(
379 field.field_number(),
380 span(reinterpret_cast<const uint32_t*>(values.data()),
381 values.size() / field.elem_size()),
382 field.varint_type()));
383 } else if (field.elem_size() == sizeof(bool)) {
384 static_assert(sizeof(bool) == sizeof(uint8_t),
385 "bool must be same size as uint8_t");
386 PW_TRY(WritePackedVarints(
387 field.field_number(),
388 span(reinterpret_cast<const uint8_t*>(values.data()),
389 values.size() / field.elem_size()),
390 field.varint_type()));
391 }
392 } else if (field.is_repeated()) {
393 // The struct member for this field is a vector of a type
394 // corresponding to the field element size. Cast to the correct
395 // vector type so we're not performing type aliasing (except for
396 // unsigned vs signed which is explicitly allowed).
397 if (field.elem_size() == sizeof(uint64_t)) {
398 const auto* vector =
399 reinterpret_cast<const pw::Vector<const uint64_t>*>(
400 values.data());
401 if (!vector->empty()) {
402 PW_TRY(WritePackedVarints(field.field_number(),
403 span(vector->data(), vector->size()),
404 field.varint_type()));
405 }
406 } else if (field.elem_size() == sizeof(uint32_t)) {
407 const auto* vector =
408 reinterpret_cast<const pw::Vector<const uint32_t>*>(
409 values.data());
410 if (!vector->empty()) {
411 PW_TRY(WritePackedVarints(field.field_number(),
412 span(vector->data(), vector->size()),
413 field.varint_type()));
414 }
415 } else if (field.elem_size() == sizeof(bool)) {
416 static_assert(sizeof(bool) == sizeof(uint8_t),
417 "bool must be same size as uint8_t");
418 const auto* vector =
419 reinterpret_cast<const pw::Vector<const uint8_t>*>(
420 values.data());
421 if (!vector->empty()) {
422 PW_TRY(WritePackedVarints(field.field_number(),
423 span(vector->data(), vector->size()),
424 field.varint_type()));
425 }
426 }
427 } else if (field.is_optional()) {
428 // The struct member for this field is a std::optional of a type
429 // corresponding to the field element size. Cast to the correct
430 // optional type so we're not performing type aliasing (except for
431 // unsigned vs signed which is explicitly allowed), and write from
432 // a temporary.
433 uint64_t value = 0;
434 if (field.elem_size() == sizeof(uint64_t)) {
435 if (field.varint_type() == VarintType::kUnsigned) {
436 const auto* optional =
437 reinterpret_cast<const std::optional<uint64_t>*>(
438 values.data());
439 if (!optional->has_value()) {
440 continue;
441 }
442 value = optional->value();
443 } else {
444 const auto* optional =
445 reinterpret_cast<const std::optional<int64_t>*>(
446 values.data());
447 if (!optional->has_value()) {
448 continue;
449 }
450 value = field.varint_type() == VarintType::kZigZag
451 ? varint::ZigZagEncode(optional->value())
452 : optional->value();
453 }
454 } else if (field.elem_size() == sizeof(uint32_t)) {
455 if (field.varint_type() == VarintType::kUnsigned) {
456 const auto* optional =
457 reinterpret_cast<const std::optional<uint32_t>*>(
458 values.data());
459 if (!optional->has_value()) {
460 continue;
461 }
462 value = optional->value();
463 } else {
464 const auto* optional =
465 reinterpret_cast<const std::optional<int32_t>*>(
466 values.data());
467 if (!optional->has_value()) {
468 continue;
469 }
470 value = field.varint_type() == VarintType::kZigZag
471 ? varint::ZigZagEncode(optional->value())
472 : optional->value();
473 }
474 } else if (field.elem_size() == sizeof(bool)) {
475 const auto* optional =
476 reinterpret_cast<const std::optional<bool>*>(values.data());
477 if (!optional->has_value()) {
478 continue;
479 }
480 value = optional->value();
481 }
482 PW_TRY(WriteVarintField(field.field_number(), value));
483 } else {
484 // The struct member for this field is a scalar of a type
485 // corresponding to the field element size. Cast to the correct
486 // type to retrieve the value before passing to WriteVarintField()
487 // so we're not performing type aliasing (except for unsigned vs
488 // signed which is explicitly allowed).
489 PW_CHECK(values.size() == field.elem_size(),
490 "Mismatched message field type and size");
491 uint64_t value = 0;
492 if (field.elem_size() == sizeof(uint64_t)) {
493 if (field.varint_type() == VarintType::kZigZag) {
494 value = varint::ZigZagEncode(
495 *reinterpret_cast<const int64_t*>(values.data()));
496 } else if (field.varint_type() == VarintType::kNormal) {
497 value = *reinterpret_cast<const int64_t*>(values.data());
498 } else {
499 value = *reinterpret_cast<const uint64_t*>(values.data());
500 }
501 if (!value) {
502 continue;
503 }
504 } else if (field.elem_size() == sizeof(uint32_t)) {
505 if (field.varint_type() == VarintType::kZigZag) {
506 value = varint::ZigZagEncode(
507 *reinterpret_cast<const int32_t*>(values.data()));
508 } else if (field.varint_type() == VarintType::kNormal) {
509 value = *reinterpret_cast<const int32_t*>(values.data());
510 } else {
511 value = *reinterpret_cast<const uint32_t*>(values.data());
512 }
513 if (!value) {
514 continue;
515 }
516 } else if (field.elem_size() == sizeof(bool)) {
517 value = *reinterpret_cast<const bool*>(values.data());
518 if (!value) {
519 continue;
520 }
521 }
522 PW_TRY(WriteVarintField(field.field_number(), value));
523 }
524 break;
525 }
526 case WireType::kDelimited: {
527 // Delimited fields are always a singular case because of the
528 // inability to cast to a generic vector with an element of a certain
529 // size (we always need a type).
530 PW_CHECK(!field.is_repeated(),
531 "Repeated delimited messages always require a callback");
532 if (field.nested_message_fields()) {
533 // Nested Message. Struct member is an embedded struct for the
534 // nested field. Obtain a nested encoder and recursively call Write()
535 // using the fields table pointer from this field.
536 auto nested_encoder = GetNestedEncoder(field.field_number(),
537 /*write_when_empty=*/false);
538 PW_TRY(nested_encoder.Write(values, *field.nested_message_fields()));
539 } else if (field.is_fixed_size()) {
540 // Fixed-length bytes field. Struct member is a std::array<std::byte>.
541 // Call WriteLengthDelimitedField() to output it to the stream.
542 PW_CHECK(field.elem_size() == sizeof(std::byte),
543 "Mismatched message field type and size");
544 if (static_cast<size_t>(
545 std::count(values.begin(), values.end(), std::byte{0})) <
546 values.size()) {
547 PW_TRY(WriteLengthDelimitedField(field.field_number(), values));
548 }
549 } else {
550 // bytes or string field with a maximum size. Struct member is
551 // pw::Vector<std::byte> for bytes or pw::InlineString<> for string.
552 // Use the contents as a span and call WriteLengthDelimitedField() to
553 // output it to the stream.
554 PW_CHECK(field.elem_size() == sizeof(std::byte),
555 "Mismatched message field type and size");
556 if (field.is_string()) {
557 PW_TRY(WriteStringOrBytes<const InlineString<>>(
558 field.field_number(), values.data()));
559 } else {
560 PW_TRY(WriteStringOrBytes<const Vector<const std::byte>>(
561 field.field_number(), values.data()));
562 }
563 }
564 break;
565 }
566 }
567 }
568
569 ResetOneOfCallbacks(message, table);
570
571 return status_;
572 }
573
ResetOneOfCallbacks(ConstByteSpan message,span<const internal::MessageField> table)574 void StreamEncoder::ResetOneOfCallbacks(
575 ConstByteSpan message, span<const internal::MessageField> table) {
576 for (const auto& field : table) {
577 // Calculate the span of bytes corresponding to the structure field to
578 // read from.
579 ConstByteSpan values =
580 message.subspan(field.field_offset(), field.field_size());
581 PW_CHECK(values.begin() >= message.begin() &&
582 values.end() <= message.end());
583
584 if (field.callback_type() == internal::CallbackType::kOneOfGroup) {
585 const OneOf<StreamEncoder, StreamDecoder>* callback =
586 reinterpret_cast<const OneOf<StreamEncoder, StreamDecoder>*>(
587 values.data());
588 callback->invoked_ = false;
589 }
590 }
591 }
592
593 } // namespace pw::protobuf
594