1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "packet_def.h"

#include <iomanip>
#include <list>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "fields/all_fields.h"
#include "packet_dependency.h"
#include "util.h"
26
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent)
29 : ParentDef(name, fields, parent) {}
30
GetNewField(const std::string &,ParseLocation) const31 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
32 return nullptr; // Packets can't be fields
33 }
34
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const35 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing,
36 bool generate_tests) const {
37 s << "class " << name_ << "View";
38 if (parent_ != nullptr) {
39 s << " : public " << parent_->name_ << "View {";
40 } else {
41 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
42 }
43 s << " public:";
44
45 // Specialize function
46 if (parent_ != nullptr) {
47 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
48 s << "{ return " << name_ << "View(std::move(parent)); }";
49 // CreateOptional
50 s << "static std::optional<" << name_ << "View> CreateOptional(";
51 s << parent_->name_ << "View parent)";
52 s << "{ auto to_validate = " << name_ << "View::Create(std::move(parent));";
53 s << "if (to_validate.IsValid()) { return to_validate; }";
54 s << "else {return {};}}";
55 } else {
56 s << "static " << name_ << "View Create(PacketView<";
57 s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
58 s << "{ return " << name_ << "View(std::move(packet)); }";
59 // CreateOptional
60 s << "static std::optional<" << name_ << "View> CreateOptional(PacketView<";
61 s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
62 s << "{ auto to_validate = " << name_ << "View::Create(std::move(packet));";
63 s << "if (to_validate.IsValid()) { return to_validate; }";
64 s << "else {return {};}}";
65 }
66
67 if (generate_fuzzing || generate_tests) {
68 GenTestingParserFromBytes(s);
69 }
70
71 std::set<std::string> fixed_types = {
72 FixedScalarField::kFieldType,
73 FixedEnumField::kFieldType,
74 };
75
76 // Print all of the public fields which are all the fields minus the fixed fields.
77 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
78 bool has_fixed_fields = public_fields.size() != fields_.size();
79 for (const auto& field : public_fields) {
80 GenParserFieldGetter(s, field);
81 s << "\n";
82 }
83 GenValidator(s);
84 s << "\n";
85
86 s << " public:";
87 GenParserToString(s);
88 s << "\n";
89
90 s << " protected:\n";
91 // Constructor from a View
92 if (parent_ != nullptr) {
93 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
94 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
95 } else {
96 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!")
97 << "kLittleEndian> packet) ";
98 s << " : PacketView<" << (is_little_endian_ ? "" : "!")
99 << "kLittleEndian>(packet) { was_validated_ = false;}";
100 }
101
102 // Print the private fields which are the fixed fields.
103 if (has_fixed_fields) {
104 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
105 s << " private:\n";
106 for (const auto& field : private_fields) {
107 GenParserFieldGetter(s, field);
108 s << "\n";
109 }
110 }
111 s << "};\n";
112 }
113
GenTestingParserFromBytes(std::ostream & s) const114 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
115 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
116
117 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
118 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
119 s << "return " << name_ << "View::Create(";
120 auto ancestor_ptr = parent_;
121 size_t parent_parens = 0;
122 while (ancestor_ptr != nullptr) {
123 s << ancestor_ptr->name_ << "View::Create(";
124 parent_parens++;
125 ancestor_ptr = ancestor_ptr->parent_;
126 }
127 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
128 for (size_t i = 0; i < parent_parens; i++) {
129 s << ")";
130 }
131 s << ");";
132 s << "}";
133
134 s << "\n#endif\n";
135 }
136
GenParserDefinitionPybind11(std::ostream & s) const137 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
138 s << "py::class_<" << name_ << "View";
139 if (parent_ != nullptr) {
140 s << ", " << parent_->name_ << "View";
141 } else {
142 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
143 }
144 s << ">(m, \"" << name_ << "View\")";
145 if (parent_ != nullptr) {
146 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
147 } else {
148 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!")
149 << "kLittleEndian> parent) {";
150 }
151 s << "auto view =" << name_ << "View::Create(std::move(parent));";
152 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
153 s << "return view; }))";
154
155 s << ".def(py::init(&" << name_ << "View::Create))";
156 std::set<std::string> protected_field_types = {
157 FixedScalarField::kFieldType,
158 FixedEnumField::kFieldType,
159 SizeField::kFieldType,
160 CountField::kFieldType,
161 };
162 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
163 for (const auto& field : public_fields) {
164 auto getter_func_name = field->GetGetterFunctionName();
165 if (getter_func_name.empty()) {
166 continue;
167 }
168 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
169 }
170 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
171 s << ";\n";
172 }
173
GenParserFieldGetter(std::ostream & s,const PacketField * field) const174 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
175 // Start field offset
176 auto start_field_offset = GetOffsetForField(field->GetName(), false);
177 auto end_field_offset = GetOffsetForField(field->GetName(), true);
178
179 if (start_field_offset.empty() && end_field_offset.empty()) {
180 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
181 << "no method exists to determine field location from begin() or end().\n";
182 }
183
184 field->GenGetter(s, start_field_offset, end_field_offset);
185 }
186
GetDefinitionType() const187 TypeDef::Type PacketDef::GetDefinitionType() const { return TypeDef::Type::PACKET; }
188
GenValidator(std::ostream & s) const189 void PacketDef::GenValidator(std::ostream& s) const {
190 // Get the static offset for all of our fields.
191 int bits_size = 0;
192 for (const auto& field : fields_) {
193 if (field->GetFieldType() != PaddingField::kFieldType) {
194 bits_size += field->GetSize().bits();
195 }
196 }
197
198 // Generate the public validator IsValid().
199 // The method only needs to be generated for the top most class.
200 if (parent_ == nullptr) {
201 s << "bool IsValid() {" << std::endl;
202 s << " if (was_validated_) {" << std::endl;
203 s << " return true;" << std::endl;
204 s << " } else {" << std::endl;
205 s << " was_validated_ = true;" << std::endl;
206 s << " return (was_validated_ = Validate());" << std::endl;
207 s << " }" << std::endl;
208 s << "}" << std::endl;
209 }
210
211 // Generate the private validator Validate().
212 // The method is overridden by all child classes.
213 s << "protected:" << std::endl;
214 if (parent_ == nullptr) {
215 s << "virtual bool Validate() const {" << std::endl;
216 } else {
217 s << "bool Validate() const override {" << std::endl;
218 s << " if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
219 s << " return false;" << std::endl;
220 s << " }" << std::endl;
221 }
222
223 // Offset by the parents known size. We know that any dynamic fields can
224 // already be called since the parent must have already been validated by
225 // this point.
226 auto parent_size = Size(0);
227 if (parent_ != nullptr) {
228 parent_size = parent_->GetSize(true);
229 }
230
231 s << "auto it = begin() + (" << parent_size << ") / 8;";
232
233 // Check if you can extract the static fields.
234 // At this point you know you can use the size getters without crashing
235 // as long as they follow the instruction that size fields cant come before
236 // their corrisponding variable length field.
237 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
238 s << "if (it > end()) return false;";
239
240 // For any variable length fields, use their size check.
241 for (const auto& field : fields_) {
242 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
243 auto offset = GetOffsetForField(field->GetName(), false);
244 if (!offset.empty()) {
245 s << "size_t sum_index = (" << offset << ") / 8;";
246 } else {
247 offset = GetOffsetForField(field->GetName(), true);
248 if (offset.empty()) {
249 ERROR(field) << "Checksum Start Field offset can not be determined.";
250 }
251 s << "size_t sum_index = size() - (" << offset << ") / 8;";
252 }
253
254 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
255 const auto& started_field = fields_.GetField(field_name);
256 if (started_field == nullptr) {
257 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "("
258 << field->GetName() << ")";
259 }
260 auto end_offset = GetOffsetForField(started_field->GetName(), false);
261 if (!end_offset.empty()) {
262 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
263 } else {
264 end_offset = GetOffsetForField(started_field->GetName(), true);
265 if (end_offset.empty()) {
266 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
267 }
268 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset
269 << ") / 8;";
270 }
271 s << "if (end_sum_index >= size()) { return false; }";
272 if (is_little_endian_) {
273 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
274 } else {
275 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
276 }
277 s << started_field->GetDataType() << " checksum;";
278 s << "checksum.Initialize();";
279 s << "for (uint8_t byte : checksum_view) { ";
280 s << "checksum.AddByte(byte);}";
281 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
282 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
283
284 continue;
285 }
286
287 auto field_size = field->GetSize();
288 // Fixed size fields have already been handled.
289 if (!field_size.has_dynamic()) {
290 continue;
291 }
292
293 // Custom fields with dynamic size must have the offset for the field passed in as well
294 // as the end iterator so that they may ensure that they don't try to read past the end.
295 // Custom fields with fixed sizes will be handled in the static offset checking.
296 if (field->GetFieldType() == CustomField::kFieldType) {
297 // Check if we can determine offset from begin(), otherwise error because by this point,
298 // the size of the custom field is unknown and can't be subtracted from end() to get the
299 // offset.
300 auto offset = GetOffsetForField(field->GetName(), false);
301 if (offset.empty()) {
302 ERROR(field) << "Custom Field offset can not be determined from begin().";
303 }
304
305 if (offset.bits() % 8 != 0) {
306 ERROR(field) << "Custom fields must be byte aligned.";
307 }
308
309 // Custom fields are special as their size field takes an argument.
310 const auto& custom_size_var = field->GetName() + "_size";
311 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
312 s << "(begin() + (" << offset << ") / 8);";
313
314 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
315 s << "it += *" << custom_size_var << ";";
316 s << "if (it > end()) return false;";
317 continue;
318 } else {
319 s << "it += (" << field_size.dynamic_string() << ") / 8;";
320 s << "if (it > end()) return false;";
321 }
322 }
323
324 // Validate constraints after validating the size
325 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
326 ERROR() << "Can't have a constraint on a NULL parent";
327 }
328
329 for (const auto& constraint : parent_constraints_) {
330 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
331 const auto& field = parent_->GetParamList().GetField(constraint.first);
332 if (field->GetFieldType() == ScalarField::kFieldType) {
333 s << std::get<int64_t>(constraint.second);
334 } else {
335 s << std::get<std::string>(constraint.second);
336 }
337 s << ") return false;";
338 }
339
340 // Validate the packets fields last
341 for (const auto& field : fields_) {
342 field->GenValidator(s);
343 s << "\n";
344 }
345
346 s << "return true;";
347 s << "}\n";
348 if (parent_ == nullptr) {
349 s << "bool was_validated_{false};\n";
350 }
351 }
352
GenParserToString(std::ostream & s) const353 void PacketDef::GenParserToString(std::ostream& s) const {
354 s << "virtual std::string ToString() const " << (parent_ != nullptr ? " override" : "") << " {";
355 s << "std::stringstream ss;";
356 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
357
358 if (fields_.size() > 0) {
359 s << "ss << \"\" ";
360 bool firstfield = true;
361 for (const auto& field : fields_) {
362 if (field->GetFieldType() == ReservedField::kFieldType ||
363 field->GetFieldType() == FixedScalarField::kFieldType ||
364 field->GetFieldType() == ChecksumStartField::kFieldType) {
365 continue;
366 }
367
368 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
369
370 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
371
372 if (firstfield) {
373 firstfield = false;
374 }
375 }
376 s << ";";
377 }
378
379 s << "ss << \" }\";";
380 s << "return ss.str();";
381 s << "}\n";
382 }
383
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const384 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing,
385 bool generate_tests) const {
386 s << "class " << name_ << "Builder";
387 if (parent_ != nullptr) {
388 s << " : public " << parent_->name_ << "Builder";
389 } else {
390 if (is_little_endian_) {
391 s << " : public PacketBuilder<kLittleEndian>";
392 } else {
393 s << " : public PacketBuilder<!kLittleEndian>";
394 }
395 }
396 s << " {";
397 s << " public:";
398 s << " virtual ~" << name_ << "Builder() = default;";
399
400 if (!fields_.HasBody()) {
401 GenBuilderCreate(s);
402 s << "\n";
403
404 if (generate_fuzzing || generate_tests) {
405 GenTestingFromView(s);
406 s << "\n";
407 }
408 }
409
410 GenSerialize(s);
411 s << "\n";
412
413 GenSize(s);
414 s << "\n";
415
416 s << " protected:\n";
417 GenBuilderConstructor(s);
418 s << "\n";
419
420 GenBuilderParameterChecker(s);
421 s << "\n";
422
423 GenMembers(s);
424 s << "};\n";
425
426 if (generate_tests) {
427 GenTestDefine(s);
428 s << "\n";
429 }
430
431 if (generate_fuzzing || generate_tests) {
432 GenReflectTestDefine(s);
433 s << "\n";
434 }
435
436 if (generate_fuzzing) {
437 GenFuzzTestDefine(s);
438 s << "\n";
439 }
440 }
441
GenTestingFromView(std::ostream & s) const442 void PacketDef::GenTestingFromView(std::ostream& s) const {
443 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
444
445 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
446 s << "if (!view.IsValid()) return nullptr;";
447 s << "return " << name_ << "Builder::Create(";
448 FieldList params = GetParamList().GetFieldsWithoutTypes({
449 BodyField::kFieldType,
450 });
451 for (std::size_t i = 0; i < params.size(); i++) {
452 params[i]->GenBuilderParameterFromView(s);
453 if (i != params.size() - 1) {
454 s << ", ";
455 }
456 }
457 s << ");";
458 s << "}";
459
460 s << "\n#endif\n";
461 }
462
GenBuilderDefinitionPybind11(std::ostream & s) const463 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
464 s << "py::class_<" << name_ << "Builder";
465 if (parent_ != nullptr) {
466 s << ", " << parent_->name_ << "Builder";
467 } else {
468 if (is_little_endian_) {
469 s << ", PacketBuilder<kLittleEndian>";
470 } else {
471 s << ", PacketBuilder<!kLittleEndian>";
472 }
473 }
474 s << ", std::shared_ptr<" << name_ << "Builder>";
475 s << ">(m, \"" << name_ << "Builder\")";
476 if (!fields_.HasBody()) {
477 GenBuilderCreatePybind11(s);
478 }
479 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
480 s << "std::vector<uint8_t> bytes;";
481 s << "BitInserter bi(bytes);";
482 s << "builder.Serialize(bi);";
483 s << "return bytes;})";
484 s << ";\n";
485 }
486
GenTestDefine(std::ostream & s) const487 void PacketDef::GenTestDefine(std::ostream& s) const {
488 s << "#ifdef PACKET_TESTING\n";
489 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
490 s << "class " << name_
491 << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
492 s << "public: ";
493 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
494 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
495 s << "if (!view.IsValid()) { log::info(\"Invalid Packet Bytes (size = {})\", view.size());";
496 s << "for (size_t i = 0; i < view.size(); i++) { log::info(\"{:5}:{:02x}\", i, *(view.begin() + "
497 "i)); }}";
498 s << "ASSERT_TRUE(view.IsValid());";
499 s << "auto packet = " << name_ << "Builder::FromView(view);";
500 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = "
501 "std::make_shared<std::vector<uint8_t>>();";
502 s << "packet_bytes->reserve(packet->size());";
503 s << "BitInserter it(*packet_bytes);";
504 s << "packet->Serialize(it);";
505 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
506 s << "}";
507 s << "};";
508 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
509 s << "CompareBytes(GetParam());";
510 s << "}";
511 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
512 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
513 int i = 0;
514 for (const auto& bytes : test_cases_) {
515 s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
516 s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
517 s << name_ << "_test_bytes_" << i << ",";
518 s << name_ << "_test_bytes_" << i << " + sizeof(";
519 s << name_ << "_test_bytes_" << i << ") - 1);";
520 i++;
521 }
522 if (!test_cases_.empty()) {
523 i = 0;
524 s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
525 for (auto bytes : test_cases_) {
526 if (i > 0) {
527 s << ",";
528 }
529 s << name_ << "_test_vec_" << i++;
530 }
531 s << ");";
532 }
533 s << "\n#endif";
534 }
535
GenReflectTestDefine(std::ostream & s) const536 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
537 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
538 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
539 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
540 s << "auto vec = std::vector<uint8_t>(data, data + size);";
541 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
542 s << "if (!view.IsValid()) { return; }";
543 s << "auto packet = " << name_ << "Builder::FromView(view);";
544 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = "
545 "std::make_shared<std::vector<uint8_t>>();";
546 s << "packet_bytes->reserve(packet->size());";
547 s << "BitInserter it(*packet_bytes);";
548 s << "packet->Serialize(it);";
549 s << "}";
550 s << "\n#endif\n";
551 }
552
GenFuzzTestDefine(std::ostream & s) const553 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
554 s << "#ifdef PACKET_FUZZ_TESTING\n";
555 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
556 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
557 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
558 s << "public: ";
559 s << "explicit " << name_
560 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& "
561 "fuzz_test_registry) {";
562 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
563 s << "}}; ";
564 s << name_ << "ReflectionFuzzTestRegistrant " << name_
565 << "_reflection_fuzz_test_registrant(REGISTRY);";
566 s << "\n#endif";
567 }
568
GetParametersToValidate() const569 FieldList PacketDef::GetParametersToValidate() const {
570 FieldList params_to_validate;
571 for (const auto& field : GetParamList()) {
572 if (field->HasParameterValidator()) {
573 params_to_validate.AppendField(field);
574 }
575 }
576 return params_to_validate;
577 }
578
GenBuilderCreate(std::ostream & s) const579 void PacketDef::GenBuilderCreate(std::ostream& s) const {
580 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
581
582 auto params = GetParamList();
583 for (std::size_t i = 0; i < params.size(); i++) {
584 params[i]->GenBuilderParameter(s);
585 if (i != params.size() - 1) {
586 s << ", ";
587 }
588 }
589 s << ") {";
590
591 // Call the constructor
592 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
593
594 params = params.GetFieldsWithoutTypes({
595 PayloadField::kFieldType,
596 BodyField::kFieldType,
597 });
598 // Add the parameters.
599 for (std::size_t i = 0; i < params.size(); i++) {
600 if (params[i]->BuilderParameterMustBeMoved()) {
601 s << "std::move(" << params[i]->GetName() << ")";
602 } else {
603 s << params[i]->GetName();
604 }
605 if (i != params.size() - 1) {
606 s << ", ";
607 }
608 }
609
610 s << "));";
611 if (fields_.HasPayload()) {
612 s << "builder->payload_ = std::move(payload);";
613 }
614 s << "return builder;";
615 s << "}\n";
616 }
617
GenBuilderCreatePybind11(std::ostream & s) const618 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
619 s << ".def(py::init([](";
620 auto params = GetParamList();
621 std::vector<std::string> constructor_args;
622 for (const auto& param : params) {
623 std::stringstream ss;
624 auto param_type = param->GetBuilderParameterType();
625 if (param_type.empty()) {
626 continue;
627 }
628 // Use shared_ptr instead of unique_ptr for the Python interface
629 if (param->BuilderParameterMustBeMoved()) {
630 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
631 }
632 ss << param_type << " " << param->GetName();
633 constructor_args.push_back(ss.str());
634 }
635 s << util::StringJoin(",", constructor_args) << "){";
636
637 // Deal with move only args
638 for (const auto& param : params) {
639 std::stringstream ss;
640 auto param_type = param->GetBuilderParameterType();
641 if (param_type.empty()) {
642 continue;
643 }
644 if (!param->BuilderParameterMustBeMoved()) {
645 continue;
646 }
647 auto move_only_param_name = param->GetName() + "_move_only";
648 s << param_type << " " << move_only_param_name << ";";
649 if (param->IsContainerField()) {
650 // Assume single layer container and copy it
651 auto struct_type = param->GetElementField()->GetDataType();
652 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
653 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
654 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
655 // Serialize each struct
656 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
657 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
658 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
659 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
660 // Parse it again
661 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName()
662 << "_bytes);";
663 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
664 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
665 // Push it into a new container
666 if (param->GetFieldType() == VectorField::kFieldType) {
667 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
668 } else if (param->GetFieldType() == ArrayField::kFieldType) {
669 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
670 } else {
671 ERROR() << param << " is not supported by Pybind11";
672 }
673 s << "}";
674 } else {
675 // Serialize the parameter and pass the bytes in a RawBuilder
676 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
677 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
678 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
679 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
680 s << move_only_param_name << " = ";
681 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
682 }
683 }
684 s << "return " << name_ << "Builder::Create(";
685 std::vector<std::string> builder_vars;
686 for (const auto& param : params) {
687 std::stringstream ss;
688 auto param_type = param->GetBuilderParameterType();
689 if (param_type.empty()) {
690 continue;
691 }
692 auto param_name = param->GetName();
693 if (param->BuilderParameterMustBeMoved()) {
694 ss << "std::move(" << param_name << "_move_only)";
695 } else {
696 ss << param_name;
697 }
698 builder_vars.push_back(ss.str());
699 }
700 s << util::StringJoin(",", builder_vars) << ");}";
701 s << "))";
702 }
703
GenBuilderParameterChecker(std::ostream & s) const704 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
705 FieldList params_to_validate = GetParametersToValidate();
706
707 // Skip writing this function if there is nothing to validate.
708 if (params_to_validate.size() == 0) {
709 return;
710 }
711
712 // Generate function arguments.
713 s << "void CheckParameterValues(";
714 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
715 params_to_validate[i]->GenBuilderParameter(s);
716 if (i != params_to_validate.size() - 1) {
717 s << ", ";
718 }
719 }
720 s << ") {";
721
722 // Check the parameters.
723 for (const auto& field : params_to_validate) {
724 field->GenParameterValidator(s);
725 }
726 s << "}\n";
727 }
728
GenBuilderConstructor(std::ostream & s) const729 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
730 s << "explicit " << name_ << "Builder(";
731
732 // Generate the constructor parameters.
733 auto params = GetParamList().GetFieldsWithoutTypes({
734 PayloadField::kFieldType,
735 BodyField::kFieldType,
736 });
737 for (std::size_t i = 0; i < params.size(); i++) {
738 params[i]->GenBuilderParameter(s);
739 if (i != params.size() - 1) {
740 s << ", ";
741 }
742 }
743 if (params.size() > 0 || parent_constraints_.size() > 0) {
744 s << ") :";
745 } else {
746 s << ")";
747 }
748
749 // Get the list of parent params to call the parent constructor with.
750 FieldList parent_params;
751 if (parent_ != nullptr) {
752 // Pass parameters to the parent constructor
753 s << parent_->name_ << "Builder(";
754 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
755 PayloadField::kFieldType,
756 BodyField::kFieldType,
757 });
758
759 // Go through all the fields and replace constrained fields with fixed values
760 // when calling the parent constructor.
761 for (std::size_t i = 0; i < parent_params.size(); i++) {
762 const auto& field = parent_params[i];
763 const auto& constraint = parent_constraints_.find(field->GetName());
764 if (constraint != parent_constraints_.end()) {
765 if (field->GetFieldType() == ScalarField::kFieldType) {
766 s << std::get<int64_t>(constraint->second);
767 } else if (field->GetFieldType() == EnumField::kFieldType) {
768 s << std::get<std::string>(constraint->second);
769 } else {
770 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
771 }
772
773 s << "/* " << field->GetName() << "_ */";
774 } else {
775 s << field->GetName();
776 }
777
778 if (i != parent_params.size() - 1) {
779 s << ", ";
780 }
781 }
782 s << ") ";
783 }
784
785 // Build a list of parameters that excludes all parent parameters.
786 FieldList saved_params;
787 for (const auto& field : params) {
788 if (parent_params.GetField(field->GetName()) == nullptr) {
789 saved_params.AppendField(field);
790 }
791 }
792 if (parent_ != nullptr && saved_params.size() > 0) {
793 s << ",";
794 }
795 for (std::size_t i = 0; i < saved_params.size(); i++) {
796 const auto& saved_param_name = saved_params[i]->GetName();
797 if (saved_params[i]->BuilderParameterMustBeMoved()) {
798 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
799 } else {
800 s << saved_param_name << "_(" << saved_param_name << ")";
801 }
802 if (i != saved_params.size() - 1) {
803 s << ",";
804 }
805 }
806 s << " {";
807
808 FieldList params_to_validate = GetParametersToValidate();
809
810 if (params_to_validate.size() > 0) {
811 s << "CheckParameterValues(";
812 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
813 s << params_to_validate[i]->GetName() << "_";
814 if (i != params_to_validate.size() - 1) {
815 s << ", ";
816 }
817 }
818 s << ");";
819 }
820
821 s << "}\n";
822 }
823