1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "parent_def.h"
18
19 #include "fields/all_fields.h"
20 #include "util.h"
21
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24 : TypeDef(name), fields_(fields), parent_(parent) {}
25
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name,
27 std::variant<int64_t, std::string> value) {
28 // NOTE: This could end up being very slow if there are a lot of constraints.
29 const auto& parent_params = parent_->GetParamList();
30 const auto& constrained_field = parent_params.GetField(field_name);
31 if (constrained_field == nullptr) {
32 ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
33 << ", but no such field exists.";
34 }
35
36 if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
37 if (!std::holds_alternative<int64_t>(value)) {
38 ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in "
39 << parent_->name_;
40 }
41 } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
42 if (!std::holds_alternative<std::string>(value)) {
43 ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in "
44 << parent_->name_;
45 }
46 const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
47 if (!enum_def.HasEntry(std::get<std::string>(value))) {
48 ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
49 << "\" for constraint on enum in parent " << parent_->name_ << ".";
50 }
51
52 // For enums, we have to qualify the value using the enum type name.
53 value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
54 } else {
55 ERROR(constrained_field) << "Field in parent " << parent_->name_
56 << " is not viable for constraining.";
57 }
58
59 parent_constraints_.insert(std::pair(field_name, value));
60 }
61
AddTestCase(std::string packet_bytes)62 void ParentDef::AddTestCase(std::string packet_bytes) {
63 test_cases_.insert(std::move(packet_bytes));
64 }
65
66 // Assign all size fields to their corresponding variable length fields.
67 // Will crash if
68 // - there aren't any fields that don't match up to a field.
69 // - the size field points to a fixed size field.
70 // - if the size field comes after the variable length field.
AssignSizeFields()71 void ParentDef::AssignSizeFields() {
72 for (const auto& field : fields_) {
73 DEBUG() << "field name: " << field->GetName();
74
75 if (field->GetFieldType() != SizeField::kFieldType &&
76 field->GetFieldType() != CountField::kFieldType) {
77 continue;
78 }
79
80 const SizeField* size_field = static_cast<SizeField*>(field);
81 // Check to see if a corresponding field can be found.
82 const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
83 if (var_len_field == nullptr) {
84 ERROR(field) << "Could not find corresponding field for size/count field.";
85 }
86
87 // Do the ordering check to ensure the size field comes before the
88 // variable length field.
89 for (auto it = fields_.begin(); *it != size_field; it++) {
90 DEBUG() << "field name: " << (*it)->GetName();
91 if (*it == var_len_field) {
92 ERROR(var_len_field, size_field)
93 << "Size/count field must come before the variable length field it describes.";
94 }
95 }
96
97 if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
98 const auto& payload_field = static_cast<PayloadField*>(var_len_field);
99 payload_field->SetSizeField(size_field);
100 continue;
101 }
102
103 if (var_len_field->GetFieldType() == BodyField::kFieldType) {
104 const auto& body_field = static_cast<BodyField*>(var_len_field);
105 body_field->SetSizeField(size_field);
106 continue;
107 }
108
109 if (var_len_field->GetFieldType() == VectorField::kFieldType) {
110 const auto& vector_field = static_cast<VectorField*>(var_len_field);
111 vector_field->SetSizeField(size_field);
112 continue;
113 }
114
115 // If we've reached this point then the field wasn't a variable length field.
116 // Check to see if the field is a variable length field
117 ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
118 }
119 }
120
SetEndianness(bool is_little_endian)121 void ParentDef::SetEndianness(bool is_little_endian) { is_little_endian_ = is_little_endian; }
122
123 // Get the size. You scan specify without_payload in order to exclude payload fields as children
124 // will be overriding it.
GetSize(bool without_payload) const125 Size ParentDef::GetSize(bool without_payload) const {
126 auto size = Size(0);
127
128 for (const auto& field : fields_) {
129 if (without_payload && (field->GetFieldType() == PayloadField::kFieldType ||
130 field->GetFieldType() == BodyField::kFieldType)) {
131 continue;
132 }
133
134 // The offset to the field must be passed in as an argument for dynamically sized custom fields.
135 if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
136 std::stringstream custom_field_size;
137
138 // Custom fields are special as their size field takes an argument.
139 custom_field_size << field->GetSize().dynamic_string() << "(begin()";
140
141 // Check if we can determine offset from begin(), otherwise error because by this point,
142 // the size of the custom field is unknown and can't be subtracted from end() to get the
143 // offset.
144 auto offset = GetOffsetForField(field->GetName(), false);
145 if (offset.empty()) {
146 ERROR(field) << "Custom Field offset can not be determined from begin().";
147 }
148
149 if (offset.bits() % 8 != 0) {
150 ERROR(field) << "Custom fields must be byte aligned.";
151 }
152 if (offset.has_bits()) {
153 custom_field_size << " + " << offset.bits() / 8;
154 }
155 if (offset.has_dynamic()) {
156 custom_field_size << " + " << offset.dynamic_string();
157 }
158 custom_field_size << ")";
159
160 size += custom_field_size.str();
161 continue;
162 }
163
164 size += field->GetSize();
165 }
166
167 if (parent_ != nullptr) {
168 size += parent_->GetSize(true);
169 }
170
171 return size;
172 }
173
174 // Get the offset until the field is reached, if there is no field
175 // returns an empty Size. from_end requests the offset to the field
176 // starting from the end() iterator. If there is a field with an unknown
177 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const178 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
179 // Check first if the field exists.
180 if (fields_.GetField(field_name) == nullptr) {
181 ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in "
182 << name_;
183 }
184
185 PacketField* padded_field = nullptr;
186 {
187 PacketField* last_field = nullptr;
188 for (const auto field : fields_) {
189 if (field->GetFieldType() == PaddingField::kFieldType) {
190 padded_field = last_field;
191 }
192 last_field = field;
193 }
194 }
195
196 // We have to use a generic lambda to conditionally change iteration direction
197 // due to iterator and reverse_iterator being different types.
198 auto size_lambda = [&field_name, padded_field, from_end](auto from, auto to) -> Size {
199 auto size = Size(0);
200 for (auto it = from; it != to; it++) {
201 // We've reached the field, end the loop.
202 if ((*it)->GetName() == field_name) {
203 break;
204 }
205 const auto& field = *it;
206 // If there is a field with an unknown size before the field, return an empty Size.
207 if (field->GetSize().empty() && padded_field != field) {
208 return Size();
209 }
210 if (field != padded_field) {
211 if (!from_end || field->GetFieldType() != PaddingField::kFieldType) {
212 size += field->GetSize();
213 }
214 }
215 }
216 return size;
217 };
218
219 // Change iteration direction based on from_end.
220 auto size = Size();
221 if (from_end) {
222 size = size_lambda(fields_.rbegin(), fields_.rend());
223 } else {
224 size = size_lambda(fields_.begin(), fields_.end());
225 }
226 if (size.empty()) {
227 return size;
228 }
229
230 // We need the offset until a payload or body field.
231 if (parent_ != nullptr) {
232 if (parent_->fields_.HasPayload()) {
233 auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
234 if (parent_payload_offset.empty()) {
235 ERROR() << "Empty offset for payload in " << parent_->name_
236 << " finding the offset for field: " << field_name;
237 }
238 size += parent_payload_offset;
239 } else {
240 auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
241 if (parent_body_offset.empty()) {
242 ERROR() << "Empty offset for body in " << parent_->name_
243 << " finding the offset for field: " << field_name;
244 }
245 size += parent_body_offset;
246 }
247 }
248
249 return size;
250 }
251
GetParamList() const252 FieldList ParentDef::GetParamList() const {
253 FieldList params;
254
255 std::set<std::string> param_types = {
256 ScalarField::kFieldType,
257 EnumField::kFieldType,
258 ArrayField::kFieldType,
259 VectorField::kFieldType,
260 CustomField::kFieldType,
261 StructField::kFieldType,
262 VariableLengthStructField::kFieldType,
263 PayloadField::kFieldType,
264 };
265
266 if (parent_ != nullptr) {
267 auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
268
269 // Do not include constrained fields in the params
270 for (const auto& field : parent_params) {
271 if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
272 params.AppendField(field);
273 }
274 }
275 }
276 // Add our parameters.
277 return params.Merge(fields_.GetFieldsWithTypes(param_types));
278 }
279
GenMembers(std::ostream & s) const280 void ParentDef::GenMembers(std::ostream& s) const {
281 // Add the parameter list.
282 for (const auto& field : fields_) {
283 if (field->GenBuilderMember(s)) {
284 s << "_{};";
285 }
286 }
287 }
288
GenSize(std::ostream & s) const289 void ParentDef::GenSize(std::ostream& s) const {
290 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
291 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
292
293 Size padded_size;
294 const PacketField* padded_field = nullptr;
295 const PacketField* last_field = nullptr;
296 for (const auto& field : fields_) {
297 if (field->GetFieldType() == PaddingField::kFieldType) {
298 if (!padded_size.empty()) {
299 ERROR() << "Only one padding field is allowed. Second field: " << field->GetName();
300 }
301 padded_field = last_field;
302 padded_size = field->GetSize();
303 }
304 last_field = field;
305 }
306
307 s << "protected:";
308 s << "size_t BitsOfHeader() const {";
309 s << "return 0";
310
311 if (parent_ != nullptr) {
312 if (parent_->GetDefinitionType() == Type::PACKET) {
313 s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
314 } else {
315 s << " + " << parent_->name_ << "::BitsOfHeader() ";
316 }
317 }
318
319 for (const auto& field : header_fields) {
320 if (field == padded_field) {
321 s << " + " << padded_size;
322 } else {
323 s << " + " << field->GetBuilderSize();
324 }
325 }
326 s << ";";
327
328 s << "}\n\n";
329
330 s << "size_t BitsOfFooter() const {";
331 s << "return 0";
332 for (const auto& field : footer_fields) {
333 if (field == padded_field) {
334 s << " + " << padded_size;
335 } else {
336 s << " + " << field->GetBuilderSize();
337 }
338 }
339
340 if (parent_ != nullptr) {
341 if (parent_->GetDefinitionType() == Type::PACKET) {
342 s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
343 } else {
344 s << " + " << parent_->name_ << "::BitsOfFooter() ";
345 }
346 }
347 s << ";";
348 s << "}\n\n";
349
350 if (fields_.HasPayload()) {
351 s << "size_t GetPayloadSize() const {";
352 s << "if (payload_ != nullptr) {return payload_->size();}";
353 s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
354 s << ";}\n\n";
355 }
356
357 s << "public:";
358 s << "virtual size_t size() const override {";
359 s << "return (BitsOfHeader() / 8)";
360 if (fields_.HasPayload()) {
361 s << "+ payload_->size()";
362 }
363 if (fields_.HasBody()) {
364 for (const auto& field : header_fields) {
365 if (field->GetFieldType() == SizeField::kFieldType) {
366 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
367 if (field_name == "body") {
368 s << "+ body_size_extracted_";
369 }
370 }
371 }
372 }
373 s << " + (BitsOfFooter() / 8);";
374 s << "}\n";
375 }
376
GenSerialize(std::ostream & s) const377 void ParentDef::GenSerialize(std::ostream& s) const {
378 auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
379 auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
380
381 s << "protected:";
382 s << "void SerializeHeader(BitInserter&";
383 if (parent_ != nullptr || header_fields.size() != 0) {
384 s << " i ";
385 }
386 s << ") const {";
387
388 if (parent_ != nullptr) {
389 if (parent_->GetDefinitionType() == Type::PACKET) {
390 s << parent_->name_ << "Builder::SerializeHeader(i);";
391 } else {
392 s << parent_->name_ << "::SerializeHeader(i);";
393 }
394 }
395
396 const PacketField* padded_field = nullptr;
397 {
398 PacketField* last_field = nullptr;
399 for (const auto field : header_fields) {
400 if (field->GetFieldType() == PaddingField::kFieldType) {
401 padded_field = last_field;
402 }
403 last_field = field;
404 }
405 }
406
407 for (const auto& field : header_fields) {
408 if (field->GetFieldType() == SizeField::kFieldType) {
409 const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
410 const auto& sized_field = fields_.GetField(field_name);
411 if (sized_field == nullptr) {
412 ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
413 }
414 if (sized_field->GetFieldType() == PayloadField::kFieldType) {
415 s << "size_t payload_bytes = GetPayloadSize();";
416 std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
417 if (modifier != "") {
418 s << "payload_bytes = payload_bytes + " << modifier.substr(1) << ";";
419 }
420 s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits()
421 << "));";
422 s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i,"
423 << field->GetSize().bits() << ");";
424 } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
425 s << field->GetName() << "_extracted_ = 0;";
426 s << "size_t local_size = " << name_ << "::size();";
427
428 s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits()
429 << "));";
430 s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i,"
431 << field->GetSize().bits() << ");";
432 } else {
433 if (sized_field->GetFieldType() != VectorField::kFieldType) {
434 ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
435 }
436 const auto& vector_name = field_name + "_";
437 const VectorField* vector = (VectorField*)sized_field;
438 s << "size_t " << vector_name + "bytes = 0;";
439 if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
440 s << "for (auto elem : " << vector_name << ") {";
441 s << vector_name + "bytes += elem.size(); }";
442 } else {
443 s << vector_name + "bytes = ";
444 s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
445 }
446 std::string modifier = vector->GetSizeModifier();
447 if (modifier != "") {
448 s << vector_name << "bytes = ";
449 s << vector_name << "bytes + " << modifier.substr(1) << ";";
450 }
451 s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
452 s << "insert(" << vector_name << "bytes, i, ";
453 s << field->GetSize().bits() << ");";
454 }
455 } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
456 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
457 const auto& started_field = fields_.GetField(field_name);
458 if (started_field == nullptr) {
459 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "("
460 << field->GetName() << ")";
461 }
462 s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
463 s << "shared_checksum_ptr->Initialize();";
464 s << "i.RegisterObserver(packet::ByteObserver(";
465 s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
466 s << "[shared_checksum_ptr](){ return "
467 "static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
468 } else if (field->GetFieldType() == PaddingField::kFieldType) {
469 s << "ASSERT(unpadded_size <= " << field->GetSize().bytes() << ");";
470 s << "size_t padding_bytes = ";
471 s << field->GetSize().bytes() << " - unpadded_size;";
472 s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
473 } else if (field->GetFieldType() == CountField::kFieldType) {
474 const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
475 s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
476 } else {
477 if (field == padded_field) {
478 s << "size_t unpadded_size = (" << field->GetBuilderSize() << ") / 8;";
479 }
480 field->GenInserter(s);
481 }
482 }
483 s << "}\n\n";
484
485 s << "void SerializeFooter(BitInserter&";
486 if (parent_ != nullptr || footer_fields.size() != 0) {
487 s << " i ";
488 }
489 s << ") const {";
490
491 for (const auto& field : footer_fields) {
492 field->GenInserter(s);
493 }
494 if (parent_ != nullptr) {
495 if (parent_->GetDefinitionType() == Type::PACKET) {
496 s << parent_->name_ << "Builder::SerializeFooter(i);";
497 } else {
498 s << parent_->name_ << "::SerializeFooter(i);";
499 }
500 }
501 s << "}\n\n";
502
503 s << "public:";
504 s << "virtual void Serialize(BitInserter& i) const override {";
505 s << "SerializeHeader(i);";
506 if (fields_.HasPayload()) {
507 s << "payload_->Serialize(i);";
508 }
509 s << "SerializeFooter(i);";
510
511 s << "}\n";
512 }
513
GenInstanceOf(std::ostream & s) const514 void ParentDef::GenInstanceOf(std::ostream& s) const {
515 if (parent_ != nullptr && parent_constraints_.size() > 0) {
516 s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
517 // Get the list of parent params.
518 FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
519 PayloadField::kFieldType,
520 BodyField::kFieldType,
521 });
522
523 // Check if constrained parent fields are set to their correct values.
524 for (const auto& field : parent_params) {
525 const auto& constraint = parent_constraints_.find(field->GetName());
526 if (constraint != parent_constraints_.end()) {
527 s << "if (parent." << field->GetName() << "_ != ";
528 if (field->GetFieldType() == ScalarField::kFieldType) {
529 s << std::get<int64_t>(constraint->second) << ")";
530 s << "{ return false;}";
531 } else if (field->GetFieldType() == EnumField::kFieldType) {
532 s << std::get<std::string>(constraint->second) << ")";
533 s << "{ return false;}";
534 } else {
535 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
536 }
537 }
538 }
539 s << "return true;}";
540 }
541 }
542
GetRootDef() const543 const ParentDef* ParentDef::GetRootDef() const {
544 if (parent_ == nullptr) {
545 return this;
546 }
547
548 return parent_->GetRootDef();
549 }
550
GetAncestors() const551 std::vector<const ParentDef*> ParentDef::GetAncestors() const {
552 std::vector<const ParentDef*> res;
553 auto parent = parent_;
554 while (parent != nullptr) {
555 res.push_back(parent);
556 parent = parent->parent_;
557 }
558 std::reverse(res.begin(), res.end());
559 return res;
560 }
561
GetAllConstraints() const562 std::map<std::string, std::variant<int64_t, std::string>> ParentDef::GetAllConstraints() const {
563 std::map<std::string, std::variant<int64_t, std::string>> res;
564 res.insert(parent_constraints_.begin(), parent_constraints_.end());
565 for (auto parent : GetAncestors()) {
566 res.insert(parent->parent_constraints_.begin(), parent->parent_constraints_.end());
567 }
568 return res;
569 }
570
HasAncestorNamed(std::string name) const571 bool ParentDef::HasAncestorNamed(std::string name) const {
572 auto parent = parent_;
573 while (parent != nullptr) {
574 if (parent->name_ == name) {
575 return true;
576 }
577 parent = parent->parent_;
578 }
579 return false;
580 }
581
FindConstraintField() const582 std::string ParentDef::FindConstraintField() const {
583 std::string res;
584 for (const auto& child : children_) {
585 if (!child->parent_constraints_.empty()) {
586 return child->parent_constraints_.begin()->first;
587 }
588 res = child->FindConstraintField();
589 }
590 return res;
591 }
592
593 std::map<const ParentDef*, const std::variant<int64_t, std::string>>
FindDescendantsWithConstraint(std::string constraint_name) const594 ParentDef::FindDescendantsWithConstraint(std::string constraint_name) const {
595 std::map<const ParentDef*, const std::variant<int64_t, std::string>> res;
596
597 for (auto const& child : children_) {
598 auto constraint = child->parent_constraints_.find(constraint_name);
599 if (constraint != child->parent_constraints_.end()) {
600 res.insert(std::pair(child, constraint->second));
601 }
602 auto m = child->FindDescendantsWithConstraint(constraint_name);
603 res.insert(m.begin(), m.end());
604 }
605 return res;
606 }
607
FindPathToDescendant(std::string descendant) const608 std::vector<const ParentDef*> ParentDef::FindPathToDescendant(std::string descendant) const {
609 std::vector<const ParentDef*> res;
610
611 for (auto const& child : children_) {
612 auto v = child->FindPathToDescendant(descendant);
613 if (v.size() > 0) {
614 res.insert(res.begin(), v.begin(), v.end());
615 res.push_back(child);
616 }
617 if (child->name_ == descendant) {
618 res.push_back(child);
619 return res;
620 }
621 }
622 return res;
623 }
624
HasChildEnums() const625 bool ParentDef::HasChildEnums() const { return !children_.empty() || fields_.HasPayload(); }
626