1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15
16 #include <cstdint>
17
18 #include "pw_function/function.h"
19 #include "pw_preprocessor/compiler.h"
20 #include "pw_protobuf/wire_format.h"
21 #include "pw_result/result.h"
22 #include "pw_span/span.h"
23 #include "pw_status/status.h"
24
// Treat -Wconversion as an error for the declarations in this header. The
// diagnostic state pushed here is popped again by the
// PW_MODIFY_DIAGNOSTICS_POP() at the end of the file.
//
// TODO: b/259746255 - Remove this manual application of -Wconversion when all
// of Pigweed builds with it.
PW_MODIFY_DIAGNOSTICS_PUSH();
PW_MODIFY_DIAGNOSTIC(error, "-Wconversion");
30
31 namespace pw::protobuf {
32 namespace internal {
33
// Varints can be encoded as an unsigned type, a signed type with normal
// encoding, or a signed type with zigzag encoding.
//
// The numeric values are stored in the two-bit varint_type slot of
// MessageField's packed info word, so they must stay within 0..3.
enum class VarintType {
  // Unsigned type.
  kUnsigned = 0,
  // Signed type with normal encoding.
  kNormal = 1,
  // Signed type with zigzag encoding.
  kZigZag = 2,
};
41
// Kind of callback associated with a field.
//
// The numeric values are stored in the two-bit callback_type slot of
// MessageField's packed info word, so they must stay within 0..3.
//
// NOTE(review): presumably kSingleField marks a field backed by a Callback
// member and kOneOfGroup one backed by a OneOf member, with kNone meaning the
// field uses no callback - confirm against the code generator.
enum class CallbackType {
  kNone = 0,
  kSingleField = 1,
  kOneOfGroup = 2,
};
47
48 // Represents a field in a code generated message struct that can be the target
49 // for decoding or source of encoding.
50 //
51 // An instance of this class exists for every field in every protobuf in the
52 // binary, thus it is size critical to ensure efficiency while retaining enough
53 // information to describe the layout of the generated message struct.
54 //
55 // Limitations imposed:
//   - Element size of a repeated field must be no larger than 15 bytes.
//     (8 byte int64/fixed64/double is the largest supported element).
58 // - Individual field size (including repeated and nested messages) must be no
59 // larger than 64 KB. (This is already the maximum size of pw::Vector).
60 //
61 // A complete codegen struct is represented by a span<MessageField>,
62 // holding a pointer to the MessageField members themselves, and the number of
63 // fields in the struct. These spans are global data, one span per protobuf
64 // message (including the size), and one MessageField per field in the message.
65 //
66 // Nested messages are handled with a pointer from the MessageField in the
67 // parent to a pointer to the (global data) span. Since the size of the nested
68 // message is stored as part of the global span, the cost of a nested message
69 // is only the size of a pointer to that span.
70 class MessageField {
71 public:
72 static constexpr unsigned int kMaxFieldSize = (1u << 16) - 1;
73
MessageField(uint32_t field_number,WireType wire_type,size_t elem_size,VarintType varint_type,bool is_string,bool is_fixed_size,bool is_repeated,bool is_optional,CallbackType callback_type,size_t field_offset,size_t field_size,const span<const MessageField> * nested_message_fields)74 constexpr MessageField(uint32_t field_number,
75 WireType wire_type,
76 size_t elem_size,
77 VarintType varint_type,
78 bool is_string,
79 bool is_fixed_size,
80 bool is_repeated,
81 bool is_optional,
82 CallbackType callback_type,
83 size_t field_offset,
84 size_t field_size,
85 const span<const MessageField>* nested_message_fields)
86 : field_number_(field_number),
87 field_info_(static_cast<uint32_t>(wire_type) << kWireTypeShift |
88 static_cast<uint32_t>(elem_size) << kElemSizeShift |
89 static_cast<uint32_t>(varint_type) << kVarintTypeShift |
90 static_cast<uint32_t>(is_string) << kIsStringShift |
91 static_cast<uint32_t>(is_fixed_size) << kIsFixedSizeShift |
92 static_cast<uint32_t>(is_repeated) << kIsRepeatedShift |
93 static_cast<uint32_t>(is_optional) << kIsOptionalShift |
94 static_cast<uint32_t>(callback_type) << kCallbackTypeShift |
95 static_cast<uint32_t>(field_size) << kFieldSizeShift),
96 field_offset_(field_offset),
97 nested_message_fields_(nested_message_fields) {}
98
field_number()99 constexpr uint32_t field_number() const { return field_number_; }
wire_type()100 constexpr WireType wire_type() const {
101 return static_cast<WireType>((field_info_ >> kWireTypeShift) &
102 kWireTypeMask);
103 }
elem_size()104 constexpr size_t elem_size() const {
105 return (field_info_ >> kElemSizeShift) & kElemSizeMask;
106 }
varint_type()107 constexpr VarintType varint_type() const {
108 return static_cast<VarintType>((field_info_ >> kVarintTypeShift) &
109 kVarintTypeMask);
110 }
is_string()111 constexpr bool is_string() const {
112 return (field_info_ >> kIsStringShift) & 1;
113 }
is_fixed_size()114 constexpr bool is_fixed_size() const {
115 return (field_info_ >> kIsFixedSizeShift) & 1;
116 }
is_repeated()117 constexpr bool is_repeated() const {
118 return (field_info_ >> kIsRepeatedShift) & 1;
119 }
is_optional()120 constexpr bool is_optional() const {
121 return (field_info_ >> kIsOptionalShift) & 1;
122 }
callback_type()123 constexpr CallbackType callback_type() const {
124 return static_cast<CallbackType>((field_info_ >> kCallbackTypeShift) &
125 kCallbackTypeMask);
126 }
field_offset()127 constexpr size_t field_offset() const { return field_offset_; }
field_size()128 constexpr size_t field_size() const {
129 return (field_info_ >> kFieldSizeShift) & kFieldSizeMask;
130 }
nested_message_fields()131 constexpr const span<const MessageField>* nested_message_fields() const {
132 return nested_message_fields_;
133 }
134
135 constexpr bool operator==(uint32_t field_number) const {
136 return field_number == field_number_;
137 }
138
139 private:
140 // field_info_ packs multiple fields into a single word as follows:
141 //
142 // wire_type : 3
143 // varint_type : 2
144 // is_string : 1
145 // is_fixed_size : 1
146 // is_repeated : 1
147 // [unused space] : 1
148 // -
149 // elem_size : 4
150 // callback_type : 2
151 // is_optional : 1
152 // -
153 // field_size : 16
154 //
155 // The protobuf field type is spread among a few fields (wire_type,
156 // varint_type, is_string, elem_size). The exact field type (e.g. int32, bool,
157 // message, etc.), from which all of that information can be derived, can be
158 // represented in 4 bits. If more bits are needed in the future, these could
159 // be consolidated into a single field type enum.
160 static constexpr unsigned int kWireTypeShift = 29u;
161 static constexpr unsigned int kWireTypeMask = (1u << 3) - 1;
162 static constexpr unsigned int kVarintTypeShift = 27u;
163 static constexpr unsigned int kVarintTypeMask = (1u << 2) - 1;
164 static constexpr unsigned int kIsStringShift = 26u;
165 static constexpr unsigned int kIsFixedSizeShift = 25u;
166 static constexpr unsigned int kIsRepeatedShift = 24u;
167 // Unused space: bit 23 (previously use_callback).
168 static constexpr unsigned int kElemSizeShift = 19u;
169 static constexpr unsigned int kElemSizeMask = (1u << 4) - 1;
170 static constexpr unsigned int kCallbackTypeShift = 17;
171 static constexpr unsigned int kCallbackTypeMask = (1u << 2) - 1;
172 static constexpr unsigned int kIsOptionalShift = 16u;
173 static constexpr unsigned int kFieldSizeShift = 0u;
174 static constexpr unsigned int kFieldSizeMask = kMaxFieldSize;
175
176 uint32_t field_number_;
177 uint32_t field_info_;
178 size_t field_offset_;
179 // TODO: b/234875722 - Could be replaced by a class MessageDescriptor*
180 const span<const MessageField>* nested_message_fields_;
181 };
182 static_assert(sizeof(MessageField) <= sizeof(size_t) * 4,
183 "MessageField should be four words or less");
184
// Constant that is false for every list of template arguments. Because the
// value nominally depends on those arguments, a static_assert on it inside a
// template is only evaluated when that template is instantiated.
template <typename... Unused>
constexpr std::false_type kInvalidMessageStruct = {};
187
188 } // namespace internal
189
190 // Callback for a structure member that cannot be represented by a data type.
191 // Holds either a callback for encoding a field, or a callback for decoding
192 // a field.
193 template <typename StreamEncoder, typename StreamDecoder>
194 union Callback {
Callback()195 constexpr Callback() : encode_() {}
~Callback()196 ~Callback() { encode_ = nullptr; }
197
198 // Set the encoder callback.
SetEncoder(Function<Status (StreamEncoder & encoder)> && encode)199 void SetEncoder(Function<Status(StreamEncoder& encoder)>&& encode) {
200 encode_ = std::move(encode);
201 }
202
203 // Set the decoder callback.
SetDecoder(Function<Status (StreamDecoder & decoder)> && decode)204 void SetDecoder(Function<Status(StreamDecoder& decoder)>&& decode) {
205 decode_ = std::move(decode);
206 }
207
208 // Allow moving of callbacks by moving the member.
209 constexpr Callback(Callback&& other) = default;
210 constexpr Callback& operator=(Callback&& other) = default;
211
212 // Copying a callback does not copy the functions.
Callback(const Callback &)213 constexpr Callback(const Callback&) : encode_() {}
214 constexpr Callback& operator=(const Callback&) {
215 encode_ = nullptr;
216 return *this;
217 }
218
219 // Evaluate to true if the encoder or decoder callback is set.
220 explicit operator bool() const { return encode_ || decode_; }
221
222 private:
223 friend StreamDecoder;
224 friend StreamEncoder;
225
226 // Called by StreamEncoder to encode the structure member.
227 // Returns OkStatus() if this has not been set by the caller, the default
228 // behavior of a field without an encoder is the same as default-initialized
229 // field.
Encode(StreamEncoder & encoder)230 Status Encode(StreamEncoder& encoder) const {
231 if (encode_) {
232 return encode_(encoder);
233 }
234 return OkStatus();
235 }
236
237 // Called by StreamDecoder to decode the structure member when the field
238 // is present. Returns DataLoss() if this has not been set by the caller.
Decode(StreamDecoder & decoder)239 Status Decode(StreamDecoder& decoder) const {
240 if (decode_) {
241 return decode_(decoder);
242 }
243 return Status::DataLoss();
244 }
245
246 Function<Status(StreamEncoder& encoder)> encode_;
247 Function<Status(StreamDecoder& decoder)> decode_;
248 };
249
// Enumeration without enumerators. Used as the default `Fields` template
// argument of OneOf when no field enumeration is supplied.
enum class NullFields : uint32_t {};
251
252 /// Callback for a oneof structure member.
253 /// A oneof callback will only be invoked once per struct member.
254 template <typename StreamEncoder,
255 typename StreamDecoder,
256 typename Fields = NullFields>
257 struct OneOf {
258 public:
OneOfOneOf259 constexpr OneOf() : invoked_(false), encode_() {}
~OneOfOneOf260 ~OneOf() { encode_ = nullptr; }
261
262 // Set the encoder callback.
SetEncoderOneOf263 void SetEncoder(Function<Status(StreamEncoder& encoder)>&& encode) {
264 encode_ = std::move(encode);
265 }
266
267 // Set the decoder callback.
SetDecoderOneOf268 void SetDecoder(
269 Function<Status(Fields field, StreamDecoder& decoder)>&& decode) {
270 decode_ = std::move(decode);
271 }
272
273 // Allow moving of callbacks by moving the member.
274 constexpr OneOf(OneOf&& other) = default;
275 constexpr OneOf& operator=(OneOf&& other) = default;
276
277 // Copying a callback does not copy the functions.
OneOfOneOf278 constexpr OneOf(const OneOf&) : encode_() {}
279 constexpr OneOf& operator=(const OneOf&) {
280 encode_ = nullptr;
281 return *this;
282 }
283
284 // Evaluate to true if the encoder or decoder callback is set.
285 explicit operator bool() const { return encode_ || decode_; }
286
287 private:
288 friend StreamDecoder;
289 friend StreamEncoder;
290
ResetForNewWriteOneOf291 constexpr void ResetForNewWrite() const { invoked_ = false; }
292
EncodeOneOf293 Status Encode(StreamEncoder& encoder) const {
294 if (encode_) {
295 if (invoked_) {
296 // The oneof has already been encoded.
297 return OkStatus();
298 }
299
300 invoked_ = true;
301 return encode_(encoder);
302 }
303 return OkStatus();
304 }
305
DecodeOneOf306 Status Decode(Fields field, StreamDecoder& decoder) const {
307 if (decode_) {
308 if (invoked_) {
309 // Multiple fields from the same oneof exist in the serialized message.
310 return Status::DataLoss();
311 }
312
313 invoked_ = true;
314 return decode_(field, decoder);
315 }
316 return OkStatus();
317 }
318
319 mutable bool invoked_;
320 union {
321 Function<Status(StreamEncoder& encoder)> encode_;
322 Function<Status(Fields field, StreamDecoder& decoder)> decode_;
323 };
324 };
325
326 template <typename T>
IsTriviallyComparable()327 constexpr bool IsTriviallyComparable() {
328 static_assert(internal::kInvalidMessageStruct<T>,
329 "Not a generated message struct");
330 return false;
331 }
332
333 } // namespace pw::protobuf
334
// Matches the PW_MODIFY_DIAGNOSTICS_PUSH() near the top of this header.
PW_MODIFY_DIAGNOSTICS_POP();
336