1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "packet_def.h"
18 
19 #include <iomanip>
20 #include <list>
21 #include <set>
22 
23 #include "fields/all_fields.h"
24 #include "packet_dependency.h"
25 #include "util.h"
26 
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent)
29     : ParentDef(name, fields, parent) {}
30 
GetNewField(const std::string &,ParseLocation) const31 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
32   return nullptr;  // Packets can't be fields
33 }
34 
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const35 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing,
36                                     bool generate_tests) const {
37   s << "class " << name_ << "View";
38   if (parent_ != nullptr) {
39     s << " : public " << parent_->name_ << "View {";
40   } else {
41     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
42   }
43   s << " public:";
44 
45   // Specialize function
46   if (parent_ != nullptr) {
47     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
48     s << "{ return " << name_ << "View(std::move(parent)); }";
49     // CreateOptional
50     s << "static std::optional<" << name_ << "View> CreateOptional(";
51     s << parent_->name_ << "View parent)";
52     s << "{ auto to_validate = " << name_ << "View::Create(std::move(parent));";
53     s << "if (to_validate.IsValid()) { return to_validate; }";
54     s << "else {return {};}}";
55   } else {
56     s << "static " << name_ << "View Create(PacketView<";
57     s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
58     s << "{ return " << name_ << "View(std::move(packet)); }";
59     // CreateOptional
60     s << "static std::optional<" << name_ << "View> CreateOptional(PacketView<";
61     s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
62     s << "{ auto to_validate = " << name_ << "View::Create(std::move(packet));";
63     s << "if (to_validate.IsValid()) { return to_validate; }";
64     s << "else {return {};}}";
65   }
66 
67   if (generate_fuzzing || generate_tests) {
68     GenTestingParserFromBytes(s);
69   }
70 
71   std::set<std::string> fixed_types = {
72           FixedScalarField::kFieldType,
73           FixedEnumField::kFieldType,
74   };
75 
76   // Print all of the public fields which are all the fields minus the fixed fields.
77   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
78   bool has_fixed_fields = public_fields.size() != fields_.size();
79   for (const auto& field : public_fields) {
80     GenParserFieldGetter(s, field);
81     s << "\n";
82   }
83   GenValidator(s);
84   s << "\n";
85 
86   s << " public:";
87   GenParserToString(s);
88   s << "\n";
89 
90   s << " protected:\n";
91   // Constructor from a View
92   if (parent_ != nullptr) {
93     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
94     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
95   } else {
96     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!")
97       << "kLittleEndian> packet) ";
98     s << " : PacketView<" << (is_little_endian_ ? "" : "!")
99       << "kLittleEndian>(packet) { was_validated_ = false;}";
100   }
101 
102   // Print the private fields which are the fixed fields.
103   if (has_fixed_fields) {
104     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
105     s << " private:\n";
106     for (const auto& field : private_fields) {
107       GenParserFieldGetter(s, field);
108       s << "\n";
109     }
110   }
111   s << "};\n";
112 }
113 
GenTestingParserFromBytes(std::ostream & s) const114 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
115   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
116 
117   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
118   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
119   s << "return " << name_ << "View::Create(";
120   auto ancestor_ptr = parent_;
121   size_t parent_parens = 0;
122   while (ancestor_ptr != nullptr) {
123     s << ancestor_ptr->name_ << "View::Create(";
124     parent_parens++;
125     ancestor_ptr = ancestor_ptr->parent_;
126   }
127   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
128   for (size_t i = 0; i < parent_parens; i++) {
129     s << ")";
130   }
131   s << ");";
132   s << "}";
133 
134   s << "\n#endif\n";
135 }
136 
GenParserDefinitionPybind11(std::ostream & s) const137 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
138   s << "py::class_<" << name_ << "View";
139   if (parent_ != nullptr) {
140     s << ", " << parent_->name_ << "View";
141   } else {
142     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
143   }
144   s << ">(m, \"" << name_ << "View\")";
145   if (parent_ != nullptr) {
146     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
147   } else {
148     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!")
149       << "kLittleEndian> parent) {";
150   }
151   s << "auto view =" << name_ << "View::Create(std::move(parent));";
152   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
153   s << "return view; }))";
154 
155   s << ".def(py::init(&" << name_ << "View::Create))";
156   std::set<std::string> protected_field_types = {
157           FixedScalarField::kFieldType,
158           FixedEnumField::kFieldType,
159           SizeField::kFieldType,
160           CountField::kFieldType,
161   };
162   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
163   for (const auto& field : public_fields) {
164     auto getter_func_name = field->GetGetterFunctionName();
165     if (getter_func_name.empty()) {
166       continue;
167     }
168     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
169   }
170   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
171   s << ";\n";
172 }
173 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const174 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
175   // Start field offset
176   auto start_field_offset = GetOffsetForField(field->GetName(), false);
177   auto end_field_offset = GetOffsetForField(field->GetName(), true);
178 
179   if (start_field_offset.empty() && end_field_offset.empty()) {
180     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
181                  << "no method exists to determine field location from begin() or end().\n";
182   }
183 
184   field->GenGetter(s, start_field_offset, end_field_offset);
185 }
186 
GetDefinitionType() const187 TypeDef::Type PacketDef::GetDefinitionType() const { return TypeDef::Type::PACKET; }
188 
GenValidator(std::ostream & s) const189 void PacketDef::GenValidator(std::ostream& s) const {
190   // Get the static offset for all of our fields.
191   int bits_size = 0;
192   for (const auto& field : fields_) {
193     if (field->GetFieldType() != PaddingField::kFieldType) {
194       bits_size += field->GetSize().bits();
195     }
196   }
197 
198   // Generate the public validator IsValid().
199   // The method only needs to be generated for the top most class.
200   if (parent_ == nullptr) {
201     s << "bool IsValid() {" << std::endl;
202     s << "  if (was_validated_) {" << std::endl;
203     s << "    return true;" << std::endl;
204     s << "  } else {" << std::endl;
205     s << "    was_validated_ = true;" << std::endl;
206     s << "    return (was_validated_ = Validate());" << std::endl;
207     s << "  }" << std::endl;
208     s << "}" << std::endl;
209   }
210 
211   // Generate the private validator Validate().
212   // The method is overridden by all child classes.
213   s << "protected:" << std::endl;
214   if (parent_ == nullptr) {
215     s << "virtual bool Validate() const {" << std::endl;
216   } else {
217     s << "bool Validate() const override {" << std::endl;
218     s << "  if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
219     s << "    return false;" << std::endl;
220     s << "  }" << std::endl;
221   }
222 
223   // Offset by the parents known size. We know that any dynamic fields can
224   // already be called since the parent must have already been validated by
225   // this point.
226   auto parent_size = Size(0);
227   if (parent_ != nullptr) {
228     parent_size = parent_->GetSize(true);
229   }
230 
231   s << "auto it = begin() + (" << parent_size << ") / 8;";
232 
233   // Check if you can extract the static fields.
234   // At this point you know you can use the size getters without crashing
235   // as long as they follow the instruction that size fields cant come before
236   // their corrisponding variable length field.
237   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
238   s << "if (it > end()) return false;";
239 
240   // For any variable length fields, use their size check.
241   for (const auto& field : fields_) {
242     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
243       auto offset = GetOffsetForField(field->GetName(), false);
244       if (!offset.empty()) {
245         s << "size_t sum_index = (" << offset << ") / 8;";
246       } else {
247         offset = GetOffsetForField(field->GetName(), true);
248         if (offset.empty()) {
249           ERROR(field) << "Checksum Start Field offset can not be determined.";
250         }
251         s << "size_t sum_index = size() - (" << offset << ") / 8;";
252       }
253 
254       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
255       const auto& started_field = fields_.GetField(field_name);
256       if (started_field == nullptr) {
257         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "("
258                      << field->GetName() << ")";
259       }
260       auto end_offset = GetOffsetForField(started_field->GetName(), false);
261       if (!end_offset.empty()) {
262         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
263       } else {
264         end_offset = GetOffsetForField(started_field->GetName(), true);
265         if (end_offset.empty()) {
266           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
267         }
268         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset
269           << ") / 8;";
270       }
271       s << "if (end_sum_index >= size()) { return false; }";
272       if (is_little_endian_) {
273         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
274       } else {
275         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
276       }
277       s << started_field->GetDataType() << " checksum;";
278       s << "checksum.Initialize();";
279       s << "for (uint8_t byte : checksum_view) { ";
280       s << "checksum.AddByte(byte);}";
281       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
282         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
283 
284       continue;
285     }
286 
287     auto field_size = field->GetSize();
288     // Fixed size fields have already been handled.
289     if (!field_size.has_dynamic()) {
290       continue;
291     }
292 
293     // Custom fields with dynamic size must have the offset for the field passed in as well
294     // as the end iterator so that they may ensure that they don't try to read past the end.
295     // Custom fields with fixed sizes will be handled in the static offset checking.
296     if (field->GetFieldType() == CustomField::kFieldType) {
297       // Check if we can determine offset from begin(), otherwise error because by this point,
298       // the size of the custom field is unknown and can't be subtracted from end() to get the
299       // offset.
300       auto offset = GetOffsetForField(field->GetName(), false);
301       if (offset.empty()) {
302         ERROR(field) << "Custom Field offset can not be determined from begin().";
303       }
304 
305       if (offset.bits() % 8 != 0) {
306         ERROR(field) << "Custom fields must be byte aligned.";
307       }
308 
309       // Custom fields are special as their size field takes an argument.
310       const auto& custom_size_var = field->GetName() + "_size";
311       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
312       s << "(begin() + (" << offset << ") / 8);";
313 
314       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
315       s << "it += *" << custom_size_var << ";";
316       s << "if (it > end()) return false;";
317       continue;
318     } else {
319       s << "it += (" << field_size.dynamic_string() << ") / 8;";
320       s << "if (it > end()) return false;";
321     }
322   }
323 
324   // Validate constraints after validating the size
325   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
326     ERROR() << "Can't have a constraint on a NULL parent";
327   }
328 
329   for (const auto& constraint : parent_constraints_) {
330     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
331     const auto& field = parent_->GetParamList().GetField(constraint.first);
332     if (field->GetFieldType() == ScalarField::kFieldType) {
333       s << std::get<int64_t>(constraint.second);
334     } else {
335       s << std::get<std::string>(constraint.second);
336     }
337     s << ") return false;";
338   }
339 
340   // Validate the packets fields last
341   for (const auto& field : fields_) {
342     field->GenValidator(s);
343     s << "\n";
344   }
345 
346   s << "return true;";
347   s << "}\n";
348   if (parent_ == nullptr) {
349     s << "bool was_validated_{false};\n";
350   }
351 }
352 
GenParserToString(std::ostream & s) const353 void PacketDef::GenParserToString(std::ostream& s) const {
354   s << "virtual std::string ToString() const " << (parent_ != nullptr ? " override" : "") << " {";
355   s << "std::stringstream ss;";
356   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
357 
358   if (fields_.size() > 0) {
359     s << "ss << \"\" ";
360     bool firstfield = true;
361     for (const auto& field : fields_) {
362       if (field->GetFieldType() == ReservedField::kFieldType ||
363           field->GetFieldType() == FixedScalarField::kFieldType ||
364           field->GetFieldType() == ChecksumStartField::kFieldType) {
365         continue;
366       }
367 
368       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
369 
370       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
371 
372       if (firstfield) {
373         firstfield = false;
374       }
375     }
376     s << ";";
377   }
378 
379   s << "ss << \" }\";";
380   s << "return ss.str();";
381   s << "}\n";
382 }
383 
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const384 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing,
385                                      bool generate_tests) const {
386   s << "class " << name_ << "Builder";
387   if (parent_ != nullptr) {
388     s << " : public " << parent_->name_ << "Builder";
389   } else {
390     if (is_little_endian_) {
391       s << " : public PacketBuilder<kLittleEndian>";
392     } else {
393       s << " : public PacketBuilder<!kLittleEndian>";
394     }
395   }
396   s << " {";
397   s << " public:";
398   s << "  virtual ~" << name_ << "Builder() = default;";
399 
400   if (!fields_.HasBody()) {
401     GenBuilderCreate(s);
402     s << "\n";
403 
404     if (generate_fuzzing || generate_tests) {
405       GenTestingFromView(s);
406       s << "\n";
407     }
408   }
409 
410   GenSerialize(s);
411   s << "\n";
412 
413   GenSize(s);
414   s << "\n";
415 
416   s << " protected:\n";
417   GenBuilderConstructor(s);
418   s << "\n";
419 
420   GenBuilderParameterChecker(s);
421   s << "\n";
422 
423   GenMembers(s);
424   s << "};\n";
425 
426   if (generate_tests) {
427     GenTestDefine(s);
428     s << "\n";
429   }
430 
431   if (generate_fuzzing || generate_tests) {
432     GenReflectTestDefine(s);
433     s << "\n";
434   }
435 
436   if (generate_fuzzing) {
437     GenFuzzTestDefine(s);
438     s << "\n";
439   }
440 }
441 
GenTestingFromView(std::ostream & s) const442 void PacketDef::GenTestingFromView(std::ostream& s) const {
443   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
444 
445   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
446   s << "if (!view.IsValid()) return nullptr;";
447   s << "return " << name_ << "Builder::Create(";
448   FieldList params = GetParamList().GetFieldsWithoutTypes({
449           BodyField::kFieldType,
450   });
451   for (std::size_t i = 0; i < params.size(); i++) {
452     params[i]->GenBuilderParameterFromView(s);
453     if (i != params.size() - 1) {
454       s << ", ";
455     }
456   }
457   s << ");";
458   s << "}";
459 
460   s << "\n#endif\n";
461 }
462 
GenBuilderDefinitionPybind11(std::ostream & s) const463 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
464   s << "py::class_<" << name_ << "Builder";
465   if (parent_ != nullptr) {
466     s << ", " << parent_->name_ << "Builder";
467   } else {
468     if (is_little_endian_) {
469       s << ", PacketBuilder<kLittleEndian>";
470     } else {
471       s << ", PacketBuilder<!kLittleEndian>";
472     }
473   }
474   s << ", std::shared_ptr<" << name_ << "Builder>";
475   s << ">(m, \"" << name_ << "Builder\")";
476   if (!fields_.HasBody()) {
477     GenBuilderCreatePybind11(s);
478   }
479   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
480   s << "std::vector<uint8_t> bytes;";
481   s << "BitInserter bi(bytes);";
482   s << "builder.Serialize(bi);";
483   s << "return bytes;})";
484   s << ";\n";
485 }
486 
GenTestDefine(std::ostream & s) const487 void PacketDef::GenTestDefine(std::ostream& s) const {
488   s << "#ifdef PACKET_TESTING\n";
489   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
490   s << "class " << name_
491     << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
492   s << "public: ";
493   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
494   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
495   s << "if (!view.IsValid()) { log::info(\"Invalid Packet Bytes (size = {})\", view.size());";
496   s << "for (size_t i = 0; i < view.size(); i++) { log::info(\"{:5}:{:02x}\", i, *(view.begin() + "
497        "i)); }}";
498   s << "ASSERT_TRUE(view.IsValid());";
499   s << "auto packet = " << name_ << "Builder::FromView(view);";
500   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = "
501        "std::make_shared<std::vector<uint8_t>>();";
502   s << "packet_bytes->reserve(packet->size());";
503   s << "BitInserter it(*packet_bytes);";
504   s << "packet->Serialize(it);";
505   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
506   s << "}";
507   s << "};";
508   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
509   s << "CompareBytes(GetParam());";
510   s << "}";
511   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
512   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
513   int i = 0;
514   for (const auto& bytes : test_cases_) {
515     s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
516     s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
517     s << name_ << "_test_bytes_" << i << ",";
518     s << name_ << "_test_bytes_" << i << " + sizeof(";
519     s << name_ << "_test_bytes_" << i << ") - 1);";
520     i++;
521   }
522   if (!test_cases_.empty()) {
523     i = 0;
524     s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
525     for (auto bytes : test_cases_) {
526       if (i > 0) {
527         s << ",";
528       }
529       s << name_ << "_test_vec_" << i++;
530     }
531     s << ");";
532   }
533   s << "\n#endif";
534 }
535 
GenReflectTestDefine(std::ostream & s) const536 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
537   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
538   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
539   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
540   s << "auto vec = std::vector<uint8_t>(data, data + size);";
541   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
542   s << "if (!view.IsValid()) { return; }";
543   s << "auto packet = " << name_ << "Builder::FromView(view);";
544   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = "
545        "std::make_shared<std::vector<uint8_t>>();";
546   s << "packet_bytes->reserve(packet->size());";
547   s << "BitInserter it(*packet_bytes);";
548   s << "packet->Serialize(it);";
549   s << "}";
550   s << "\n#endif\n";
551 }
552 
GenFuzzTestDefine(std::ostream & s) const553 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
554   s << "#ifdef PACKET_FUZZ_TESTING\n";
555   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
556   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
557   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
558   s << "public: ";
559   s << "explicit " << name_
560     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& "
561        "fuzz_test_registry) {";
562   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
563   s << "}}; ";
564   s << name_ << "ReflectionFuzzTestRegistrant " << name_
565     << "_reflection_fuzz_test_registrant(REGISTRY);";
566   s << "\n#endif";
567 }
568 
GetParametersToValidate() const569 FieldList PacketDef::GetParametersToValidate() const {
570   FieldList params_to_validate;
571   for (const auto& field : GetParamList()) {
572     if (field->HasParameterValidator()) {
573       params_to_validate.AppendField(field);
574     }
575   }
576   return params_to_validate;
577 }
578 
GenBuilderCreate(std::ostream & s) const579 void PacketDef::GenBuilderCreate(std::ostream& s) const {
580   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
581 
582   auto params = GetParamList();
583   for (std::size_t i = 0; i < params.size(); i++) {
584     params[i]->GenBuilderParameter(s);
585     if (i != params.size() - 1) {
586       s << ", ";
587     }
588   }
589   s << ") {";
590 
591   // Call the constructor
592   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
593 
594   params = params.GetFieldsWithoutTypes({
595           PayloadField::kFieldType,
596           BodyField::kFieldType,
597   });
598   // Add the parameters.
599   for (std::size_t i = 0; i < params.size(); i++) {
600     if (params[i]->BuilderParameterMustBeMoved()) {
601       s << "std::move(" << params[i]->GetName() << ")";
602     } else {
603       s << params[i]->GetName();
604     }
605     if (i != params.size() - 1) {
606       s << ", ";
607     }
608   }
609 
610   s << "));";
611   if (fields_.HasPayload()) {
612     s << "builder->payload_ = std::move(payload);";
613   }
614   s << "return builder;";
615   s << "}\n";
616 }
617 
GenBuilderCreatePybind11(std::ostream & s) const618 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
619   s << ".def(py::init([](";
620   auto params = GetParamList();
621   std::vector<std::string> constructor_args;
622   for (const auto& param : params) {
623     std::stringstream ss;
624     auto param_type = param->GetBuilderParameterType();
625     if (param_type.empty()) {
626       continue;
627     }
628     // Use shared_ptr instead of unique_ptr for the Python interface
629     if (param->BuilderParameterMustBeMoved()) {
630       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
631     }
632     ss << param_type << " " << param->GetName();
633     constructor_args.push_back(ss.str());
634   }
635   s << util::StringJoin(",", constructor_args) << "){";
636 
637   // Deal with move only args
638   for (const auto& param : params) {
639     std::stringstream ss;
640     auto param_type = param->GetBuilderParameterType();
641     if (param_type.empty()) {
642       continue;
643     }
644     if (!param->BuilderParameterMustBeMoved()) {
645       continue;
646     }
647     auto move_only_param_name = param->GetName() + "_move_only";
648     s << param_type << " " << move_only_param_name << ";";
649     if (param->IsContainerField()) {
650       // Assume single layer container and copy it
651       auto struct_type = param->GetElementField()->GetDataType();
652       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
653       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
654       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
655       // Serialize each struct
656       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
657       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
658       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
659       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
660       // Parse it again
661       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName()
662         << "_bytes);";
663       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
664       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
665       // Push it into a new container
666       if (param->GetFieldType() == VectorField::kFieldType) {
667         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
668       } else if (param->GetFieldType() == ArrayField::kFieldType) {
669         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
670       } else {
671         ERROR() << param << " is not supported by Pybind11";
672       }
673       s << "}";
674     } else {
675       // Serialize the parameter and pass the bytes in a RawBuilder
676       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
677       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
678       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
679       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
680       s << move_only_param_name << " = ";
681       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
682     }
683   }
684   s << "return " << name_ << "Builder::Create(";
685   std::vector<std::string> builder_vars;
686   for (const auto& param : params) {
687     std::stringstream ss;
688     auto param_type = param->GetBuilderParameterType();
689     if (param_type.empty()) {
690       continue;
691     }
692     auto param_name = param->GetName();
693     if (param->BuilderParameterMustBeMoved()) {
694       ss << "std::move(" << param_name << "_move_only)";
695     } else {
696       ss << param_name;
697     }
698     builder_vars.push_back(ss.str());
699   }
700   s << util::StringJoin(",", builder_vars) << ");}";
701   s << "))";
702 }
703 
GenBuilderParameterChecker(std::ostream & s) const704 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
705   FieldList params_to_validate = GetParametersToValidate();
706 
707   // Skip writing this function if there is nothing to validate.
708   if (params_to_validate.size() == 0) {
709     return;
710   }
711 
712   // Generate function arguments.
713   s << "void CheckParameterValues(";
714   for (std::size_t i = 0; i < params_to_validate.size(); i++) {
715     params_to_validate[i]->GenBuilderParameter(s);
716     if (i != params_to_validate.size() - 1) {
717       s << ", ";
718     }
719   }
720   s << ") {";
721 
722   // Check the parameters.
723   for (const auto& field : params_to_validate) {
724     field->GenParameterValidator(s);
725   }
726   s << "}\n";
727 }
728 
GenBuilderConstructor(std::ostream & s) const729 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
730   s << "explicit " << name_ << "Builder(";
731 
732   // Generate the constructor parameters.
733   auto params = GetParamList().GetFieldsWithoutTypes({
734           PayloadField::kFieldType,
735           BodyField::kFieldType,
736   });
737   for (std::size_t i = 0; i < params.size(); i++) {
738     params[i]->GenBuilderParameter(s);
739     if (i != params.size() - 1) {
740       s << ", ";
741     }
742   }
743   if (params.size() > 0 || parent_constraints_.size() > 0) {
744     s << ") :";
745   } else {
746     s << ")";
747   }
748 
749   // Get the list of parent params to call the parent constructor with.
750   FieldList parent_params;
751   if (parent_ != nullptr) {
752     // Pass parameters to the parent constructor
753     s << parent_->name_ << "Builder(";
754     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
755             PayloadField::kFieldType,
756             BodyField::kFieldType,
757     });
758 
759     // Go through all the fields and replace constrained fields with fixed values
760     // when calling the parent constructor.
761     for (std::size_t i = 0; i < parent_params.size(); i++) {
762       const auto& field = parent_params[i];
763       const auto& constraint = parent_constraints_.find(field->GetName());
764       if (constraint != parent_constraints_.end()) {
765         if (field->GetFieldType() == ScalarField::kFieldType) {
766           s << std::get<int64_t>(constraint->second);
767         } else if (field->GetFieldType() == EnumField::kFieldType) {
768           s << std::get<std::string>(constraint->second);
769         } else {
770           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
771         }
772 
773         s << "/* " << field->GetName() << "_ */";
774       } else {
775         s << field->GetName();
776       }
777 
778       if (i != parent_params.size() - 1) {
779         s << ", ";
780       }
781     }
782     s << ") ";
783   }
784 
785   // Build a list of parameters that excludes all parent parameters.
786   FieldList saved_params;
787   for (const auto& field : params) {
788     if (parent_params.GetField(field->GetName()) == nullptr) {
789       saved_params.AppendField(field);
790     }
791   }
792   if (parent_ != nullptr && saved_params.size() > 0) {
793     s << ",";
794   }
795   for (std::size_t i = 0; i < saved_params.size(); i++) {
796     const auto& saved_param_name = saved_params[i]->GetName();
797     if (saved_params[i]->BuilderParameterMustBeMoved()) {
798       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
799     } else {
800       s << saved_param_name << "_(" << saved_param_name << ")";
801     }
802     if (i != saved_params.size() - 1) {
803       s << ",";
804     }
805   }
806   s << " {";
807 
808   FieldList params_to_validate = GetParametersToValidate();
809 
810   if (params_to_validate.size() > 0) {
811     s << "CheckParameterValues(";
812     for (std::size_t i = 0; i < params_to_validate.size(); i++) {
813       s << params_to_validate[i]->GetName() << "_";
814       if (i != params_to_validate.size() - 1) {
815         s << ", ";
816       }
817     }
818     s << ");";
819   }
820 
821   s << "}\n";
822 }
823