1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "parent_def.h"
18 
19 #include "fields/all_fields.h"
20 #include "util.h"
21 
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24     : TypeDef(name), fields_(fields), parent_(parent) {}
25 
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name,
27                                     std::variant<int64_t, std::string> value) {
28   // NOTE: This could end up being very slow if there are a lot of constraints.
29   const auto& parent_params = parent_->GetParamList();
30   const auto& constrained_field = parent_params.GetField(field_name);
31   if (constrained_field == nullptr) {
32     ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
33             << ", but no such field exists.";
34   }
35 
36   if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
37     if (!std::holds_alternative<int64_t>(value)) {
38       ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in "
39                                << parent_->name_;
40     }
41   } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
42     if (!std::holds_alternative<std::string>(value)) {
43       ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in "
44                                << parent_->name_;
45     }
46     const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
47     if (!enum_def.HasEntry(std::get<std::string>(value))) {
48       ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
49                                << "\" for constraint on enum in parent " << parent_->name_ << ".";
50     }
51 
52     // For enums, we have to qualify the value using the enum type name.
53     value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
54   } else {
55     ERROR(constrained_field) << "Field in parent " << parent_->name_
56                              << " is not viable for constraining.";
57   }
58 
59   parent_constraints_.insert(std::pair(field_name, value));
60 }
61 
AddTestCase(std::string packet_bytes)62 void ParentDef::AddTestCase(std::string packet_bytes) {
63   test_cases_.insert(std::move(packet_bytes));
64 }
65 
66 // Assign all size fields to their corresponding variable length fields.
67 // Will crash if
68 //  - there aren't any fields that don't match up to a field.
69 //  - the size field points to a fixed size field.
70 //  - if the size field comes after the variable length field.
AssignSizeFields()71 void ParentDef::AssignSizeFields() {
72   for (const auto& field : fields_) {
73     DEBUG() << "field name: " << field->GetName();
74 
75     if (field->GetFieldType() != SizeField::kFieldType &&
76         field->GetFieldType() != CountField::kFieldType) {
77       continue;
78     }
79 
80     const SizeField* size_field = static_cast<SizeField*>(field);
81     // Check to see if a corresponding field can be found.
82     const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
83     if (var_len_field == nullptr) {
84       ERROR(field) << "Could not find corresponding field for size/count field.";
85     }
86 
87     // Do the ordering check to ensure the size field comes before the
88     // variable length field.
89     for (auto it = fields_.begin(); *it != size_field; it++) {
90       DEBUG() << "field name: " << (*it)->GetName();
91       if (*it == var_len_field) {
92         ERROR(var_len_field, size_field)
93                 << "Size/count field must come before the variable length field it describes.";
94       }
95     }
96 
97     if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
98       const auto& payload_field = static_cast<PayloadField*>(var_len_field);
99       payload_field->SetSizeField(size_field);
100       continue;
101     }
102 
103     if (var_len_field->GetFieldType() == BodyField::kFieldType) {
104       const auto& body_field = static_cast<BodyField*>(var_len_field);
105       body_field->SetSizeField(size_field);
106       continue;
107     }
108 
109     if (var_len_field->GetFieldType() == VectorField::kFieldType) {
110       const auto& vector_field = static_cast<VectorField*>(var_len_field);
111       vector_field->SetSizeField(size_field);
112       continue;
113     }
114 
115     // If we've reached this point then the field wasn't a variable length field.
116     // Check to see if the field is a variable length field
117     ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
118   }
119 }
120 
SetEndianness(bool is_little_endian)121 void ParentDef::SetEndianness(bool is_little_endian) { is_little_endian_ = is_little_endian; }
122 
123 // Get the size. You scan specify without_payload in order to exclude payload fields as children
124 // will be overriding it.
GetSize(bool without_payload) const125 Size ParentDef::GetSize(bool without_payload) const {
126   auto size = Size(0);
127 
128   for (const auto& field : fields_) {
129     if (without_payload && (field->GetFieldType() == PayloadField::kFieldType ||
130                             field->GetFieldType() == BodyField::kFieldType)) {
131       continue;
132     }
133 
134     // The offset to the field must be passed in as an argument for dynamically sized custom fields.
135     if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
136       std::stringstream custom_field_size;
137 
138       // Custom fields are special as their size field takes an argument.
139       custom_field_size << field->GetSize().dynamic_string() << "(begin()";
140 
141       // Check if we can determine offset from begin(), otherwise error because by this point,
142       // the size of the custom field is unknown and can't be subtracted from end() to get the
143       // offset.
144       auto offset = GetOffsetForField(field->GetName(), false);
145       if (offset.empty()) {
146         ERROR(field) << "Custom Field offset can not be determined from begin().";
147       }
148 
149       if (offset.bits() % 8 != 0) {
150         ERROR(field) << "Custom fields must be byte aligned.";
151       }
152       if (offset.has_bits()) {
153         custom_field_size << " + " << offset.bits() / 8;
154       }
155       if (offset.has_dynamic()) {
156         custom_field_size << " + " << offset.dynamic_string();
157       }
158       custom_field_size << ")";
159 
160       size += custom_field_size.str();
161       continue;
162     }
163 
164     size += field->GetSize();
165   }
166 
167   if (parent_ != nullptr) {
168     size += parent_->GetSize(true);
169   }
170 
171   return size;
172 }
173 
174 // Get the offset until the field is reached, if there is no field
175 // returns an empty Size. from_end requests the offset to the field
176 // starting from the end() iterator. If there is a field with an unknown
177 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const178 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
179   // Check first if the field exists.
180   if (fields_.GetField(field_name) == nullptr) {
181     ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in "
182             << name_;
183   }
184 
185   PacketField* padded_field = nullptr;
186   {
187     PacketField* last_field = nullptr;
188     for (const auto field : fields_) {
189       if (field->GetFieldType() == PaddingField::kFieldType) {
190         padded_field = last_field;
191       }
192       last_field = field;
193     }
194   }
195 
196   // We have to use a generic lambda to conditionally change iteration direction
197   // due to iterator and reverse_iterator being different types.
198   auto size_lambda = [&field_name, padded_field, from_end](auto from, auto to) -> Size {
199     auto size = Size(0);
200     for (auto it = from; it != to; it++) {
201       // We've reached the field, end the loop.
202       if ((*it)->GetName() == field_name) {
203         break;
204       }
205       const auto& field = *it;
206       // If there is a field with an unknown size before the field, return an empty Size.
207       if (field->GetSize().empty() && padded_field != field) {
208         return Size();
209       }
210       if (field != padded_field) {
211         if (!from_end || field->GetFieldType() != PaddingField::kFieldType) {
212           size += field->GetSize();
213         }
214       }
215     }
216     return size;
217   };
218 
219   // Change iteration direction based on from_end.
220   auto size = Size();
221   if (from_end) {
222     size = size_lambda(fields_.rbegin(), fields_.rend());
223   } else {
224     size = size_lambda(fields_.begin(), fields_.end());
225   }
226   if (size.empty()) {
227     return size;
228   }
229 
230   // We need the offset until a payload or body field.
231   if (parent_ != nullptr) {
232     if (parent_->fields_.HasPayload()) {
233       auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
234       if (parent_payload_offset.empty()) {
235         ERROR() << "Empty offset for payload in " << parent_->name_
236                 << " finding the offset for field: " << field_name;
237       }
238       size += parent_payload_offset;
239     } else {
240       auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
241       if (parent_body_offset.empty()) {
242         ERROR() << "Empty offset for body in " << parent_->name_
243                 << " finding the offset for field: " << field_name;
244       }
245       size += parent_body_offset;
246     }
247   }
248 
249   return size;
250 }
251 
GetParamList() const252 FieldList ParentDef::GetParamList() const {
253   FieldList params;
254 
255   std::set<std::string> param_types = {
256           ScalarField::kFieldType,
257           EnumField::kFieldType,
258           ArrayField::kFieldType,
259           VectorField::kFieldType,
260           CustomField::kFieldType,
261           StructField::kFieldType,
262           VariableLengthStructField::kFieldType,
263           PayloadField::kFieldType,
264   };
265 
266   if (parent_ != nullptr) {
267     auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
268 
269     // Do not include constrained fields in the params
270     for (const auto& field : parent_params) {
271       if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
272         params.AppendField(field);
273       }
274     }
275   }
276   // Add our parameters.
277   return params.Merge(fields_.GetFieldsWithTypes(param_types));
278 }
279 
GenMembers(std::ostream & s) const280 void ParentDef::GenMembers(std::ostream& s) const {
281   // Add the parameter list.
282   for (const auto& field : fields_) {
283     if (field->GenBuilderMember(s)) {
284       s << "_{};";
285     }
286   }
287 }
288 
GenSize(std::ostream & s) const289 void ParentDef::GenSize(std::ostream& s) const {
290   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
291   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
292 
293   Size padded_size;
294   const PacketField* padded_field = nullptr;
295   const PacketField* last_field = nullptr;
296   for (const auto& field : fields_) {
297     if (field->GetFieldType() == PaddingField::kFieldType) {
298       if (!padded_size.empty()) {
299         ERROR() << "Only one padding field is allowed.  Second field: " << field->GetName();
300       }
301       padded_field = last_field;
302       padded_size = field->GetSize();
303     }
304     last_field = field;
305   }
306 
307   s << "protected:";
308   s << "size_t BitsOfHeader() const {";
309   s << "return 0";
310 
311   if (parent_ != nullptr) {
312     if (parent_->GetDefinitionType() == Type::PACKET) {
313       s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
314     } else {
315       s << " + " << parent_->name_ << "::BitsOfHeader() ";
316     }
317   }
318 
319   for (const auto& field : header_fields) {
320     if (field == padded_field) {
321       s << " + " << padded_size;
322     } else {
323       s << " + " << field->GetBuilderSize();
324     }
325   }
326   s << ";";
327 
328   s << "}\n\n";
329 
330   s << "size_t BitsOfFooter() const {";
331   s << "return 0";
332   for (const auto& field : footer_fields) {
333     if (field == padded_field) {
334       s << " + " << padded_size;
335     } else {
336       s << " + " << field->GetBuilderSize();
337     }
338   }
339 
340   if (parent_ != nullptr) {
341     if (parent_->GetDefinitionType() == Type::PACKET) {
342       s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
343     } else {
344       s << " + " << parent_->name_ << "::BitsOfFooter() ";
345     }
346   }
347   s << ";";
348   s << "}\n\n";
349 
350   if (fields_.HasPayload()) {
351     s << "size_t GetPayloadSize() const {";
352     s << "if (payload_ != nullptr) {return payload_->size();}";
353     s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
354     s << ";}\n\n";
355   }
356 
357   s << "public:";
358   s << "virtual size_t size() const override {";
359   s << "return (BitsOfHeader() / 8)";
360   if (fields_.HasPayload()) {
361     s << "+ payload_->size()";
362   }
363   if (fields_.HasBody()) {
364     for (const auto& field : header_fields) {
365       if (field->GetFieldType() == SizeField::kFieldType) {
366         const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
367         if (field_name == "body") {
368           s << "+ body_size_extracted_";
369         }
370       }
371     }
372   }
373   s << " + (BitsOfFooter() / 8);";
374   s << "}\n";
375 }
376 
GenSerialize(std::ostream & s) const377 void ParentDef::GenSerialize(std::ostream& s) const {
378   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
379   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
380 
381   s << "protected:";
382   s << "void SerializeHeader(BitInserter&";
383   if (parent_ != nullptr || header_fields.size() != 0) {
384     s << " i ";
385   }
386   s << ") const {";
387 
388   if (parent_ != nullptr) {
389     if (parent_->GetDefinitionType() == Type::PACKET) {
390       s << parent_->name_ << "Builder::SerializeHeader(i);";
391     } else {
392       s << parent_->name_ << "::SerializeHeader(i);";
393     }
394   }
395 
396   const PacketField* padded_field = nullptr;
397   {
398     PacketField* last_field = nullptr;
399     for (const auto field : header_fields) {
400       if (field->GetFieldType() == PaddingField::kFieldType) {
401         padded_field = last_field;
402       }
403       last_field = field;
404     }
405   }
406 
407   for (const auto& field : header_fields) {
408     if (field->GetFieldType() == SizeField::kFieldType) {
409       const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
410       const auto& sized_field = fields_.GetField(field_name);
411       if (sized_field == nullptr) {
412         ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
413       }
414       if (sized_field->GetFieldType() == PayloadField::kFieldType) {
415         s << "size_t payload_bytes = GetPayloadSize();";
416         std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
417         if (modifier != "") {
418           s << "payload_bytes = payload_bytes + " << modifier.substr(1) << ";";
419         }
420         s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits()
421           << "));";
422         s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i,"
423           << field->GetSize().bits() << ");";
424       } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
425         s << field->GetName() << "_extracted_ = 0;";
426         s << "size_t local_size = " << name_ << "::size();";
427 
428         s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits()
429           << "));";
430         s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i,"
431           << field->GetSize().bits() << ");";
432       } else {
433         if (sized_field->GetFieldType() != VectorField::kFieldType) {
434           ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
435         }
436         const auto& vector_name = field_name + "_";
437         const VectorField* vector = (VectorField*)sized_field;
438         s << "size_t " << vector_name + "bytes = 0;";
439         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
440           s << "for (auto elem : " << vector_name << ") {";
441           s << vector_name + "bytes += elem.size(); }";
442         } else {
443           s << vector_name + "bytes = ";
444           s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
445         }
446         std::string modifier = vector->GetSizeModifier();
447         if (modifier != "") {
448           s << vector_name << "bytes = ";
449           s << vector_name << "bytes + " << modifier.substr(1) << ";";
450         }
451         s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
452         s << "insert(" << vector_name << "bytes, i, ";
453         s << field->GetSize().bits() << ");";
454       }
455     } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
456       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
457       const auto& started_field = fields_.GetField(field_name);
458       if (started_field == nullptr) {
459         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "("
460                      << field->GetName() << ")";
461       }
462       s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
463       s << "shared_checksum_ptr->Initialize();";
464       s << "i.RegisterObserver(packet::ByteObserver(";
465       s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
466       s << "[shared_checksum_ptr](){ return "
467            "static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
468     } else if (field->GetFieldType() == PaddingField::kFieldType) {
469       s << "ASSERT(unpadded_size <= " << field->GetSize().bytes() << ");";
470       s << "size_t padding_bytes = ";
471       s << field->GetSize().bytes() << " - unpadded_size;";
472       s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
473     } else if (field->GetFieldType() == CountField::kFieldType) {
474       const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
475       s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
476     } else {
477       if (field == padded_field) {
478         s << "size_t unpadded_size = (" << field->GetBuilderSize() << ") / 8;";
479       }
480       field->GenInserter(s);
481     }
482   }
483   s << "}\n\n";
484 
485   s << "void SerializeFooter(BitInserter&";
486   if (parent_ != nullptr || footer_fields.size() != 0) {
487     s << " i ";
488   }
489   s << ") const {";
490 
491   for (const auto& field : footer_fields) {
492     field->GenInserter(s);
493   }
494   if (parent_ != nullptr) {
495     if (parent_->GetDefinitionType() == Type::PACKET) {
496       s << parent_->name_ << "Builder::SerializeFooter(i);";
497     } else {
498       s << parent_->name_ << "::SerializeFooter(i);";
499     }
500   }
501   s << "}\n\n";
502 
503   s << "public:";
504   s << "virtual void Serialize(BitInserter& i) const override {";
505   s << "SerializeHeader(i);";
506   if (fields_.HasPayload()) {
507     s << "payload_->Serialize(i);";
508   }
509   s << "SerializeFooter(i);";
510 
511   s << "}\n";
512 }
513 
GenInstanceOf(std::ostream & s) const514 void ParentDef::GenInstanceOf(std::ostream& s) const {
515   if (parent_ != nullptr && parent_constraints_.size() > 0) {
516     s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
517     // Get the list of parent params.
518     FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
519             PayloadField::kFieldType,
520             BodyField::kFieldType,
521     });
522 
523     // Check if constrained parent fields are set to their correct values.
524     for (const auto& field : parent_params) {
525       const auto& constraint = parent_constraints_.find(field->GetName());
526       if (constraint != parent_constraints_.end()) {
527         s << "if (parent." << field->GetName() << "_ != ";
528         if (field->GetFieldType() == ScalarField::kFieldType) {
529           s << std::get<int64_t>(constraint->second) << ")";
530           s << "{ return false;}";
531         } else if (field->GetFieldType() == EnumField::kFieldType) {
532           s << std::get<std::string>(constraint->second) << ")";
533           s << "{ return false;}";
534         } else {
535           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
536         }
537       }
538     }
539     s << "return true;}";
540   }
541 }
542 
GetRootDef() const543 const ParentDef* ParentDef::GetRootDef() const {
544   if (parent_ == nullptr) {
545     return this;
546   }
547 
548   return parent_->GetRootDef();
549 }
550 
GetAncestors() const551 std::vector<const ParentDef*> ParentDef::GetAncestors() const {
552   std::vector<const ParentDef*> res;
553   auto parent = parent_;
554   while (parent != nullptr) {
555     res.push_back(parent);
556     parent = parent->parent_;
557   }
558   std::reverse(res.begin(), res.end());
559   return res;
560 }
561 
GetAllConstraints() const562 std::map<std::string, std::variant<int64_t, std::string>> ParentDef::GetAllConstraints() const {
563   std::map<std::string, std::variant<int64_t, std::string>> res;
564   res.insert(parent_constraints_.begin(), parent_constraints_.end());
565   for (auto parent : GetAncestors()) {
566     res.insert(parent->parent_constraints_.begin(), parent->parent_constraints_.end());
567   }
568   return res;
569 }
570 
HasAncestorNamed(std::string name) const571 bool ParentDef::HasAncestorNamed(std::string name) const {
572   auto parent = parent_;
573   while (parent != nullptr) {
574     if (parent->name_ == name) {
575       return true;
576     }
577     parent = parent->parent_;
578   }
579   return false;
580 }
581 
FindConstraintField() const582 std::string ParentDef::FindConstraintField() const {
583   std::string res;
584   for (const auto& child : children_) {
585     if (!child->parent_constraints_.empty()) {
586       return child->parent_constraints_.begin()->first;
587     }
588     res = child->FindConstraintField();
589   }
590   return res;
591 }
592 
593 std::map<const ParentDef*, const std::variant<int64_t, std::string>>
FindDescendantsWithConstraint(std::string constraint_name) const594 ParentDef::FindDescendantsWithConstraint(std::string constraint_name) const {
595   std::map<const ParentDef*, const std::variant<int64_t, std::string>> res;
596 
597   for (auto const& child : children_) {
598     auto constraint = child->parent_constraints_.find(constraint_name);
599     if (constraint != child->parent_constraints_.end()) {
600       res.insert(std::pair(child, constraint->second));
601     }
602     auto m = child->FindDescendantsWithConstraint(constraint_name);
603     res.insert(m.begin(), m.end());
604   }
605   return res;
606 }
607 
FindPathToDescendant(std::string descendant) const608 std::vector<const ParentDef*> ParentDef::FindPathToDescendant(std::string descendant) const {
609   std::vector<const ParentDef*> res;
610 
611   for (auto const& child : children_) {
612     auto v = child->FindPathToDescendant(descendant);
613     if (v.size() > 0) {
614       res.insert(res.begin(), v.begin(), v.end());
615       res.push_back(child);
616     }
617     if (child->name_ == descendant) {
618       res.push_back(child);
619       return res;
620     }
621   }
622   return res;
623 }
624 
HasChildEnums() const625 bool ParentDef::HasChildEnums() const { return !children_.empty() || fields_.HasPayload(); }
626