1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "struct_def.h"

#include <utility>

#include "fields/all_fields.h"
#include "util.h"
21 
StructDef(std::string name,FieldList fields)22 StructDef::StructDef(std::string name, FieldList fields) : StructDef(name, fields, nullptr) {}
StructDef(std::string name,FieldList fields,StructDef * parent)23 StructDef::StructDef(std::string name, FieldList fields, StructDef* parent)
24     : ParentDef(name, fields, parent), total_size_(GetSize(true)) {}
25 
GetNewField(const std::string & name,ParseLocation loc) const26 PacketField* StructDef::GetNewField(const std::string& name, ParseLocation loc) const {
27   if (fields_.HasBody()) {
28     return new VariableLengthStructField(name, name_, loc);
29   } else {
30     return new StructField(name, name_, total_size_, loc);
31   }
32 }
33 
GetDefinitionType() const34 TypeDef::Type StructDef::GetDefinitionType() const { return TypeDef::Type::STRUCT; }
35 
GenSpecialize(std::ostream & s) const36 void StructDef::GenSpecialize(std::ostream& s) const {
37   if (parent_ == nullptr) {
38     return;
39   }
40   s << "static " << name_ << "* Specialize(" << parent_->name_ << "* parent) {";
41   s << "ASSERT(" << name_ << "::IsInstance(*parent));";
42   s << "return static_cast<" << name_ << "*>(parent);";
43   s << "}";
44 }
45 
GenToString(std::ostream & s) const46 void StructDef::GenToString(std::ostream& s) const {
47   s << "std::string ToString() const {";
48   s << "std::stringstream ss;";
49   s << "ss << std::hex << std::showbase << \"" << name_ << " { \";";
50 
51   if (fields_.size() > 0) {
52     s << "ss";
53     bool firstfield = true;
54     for (const auto& field : fields_) {
55       if (field->GetFieldType() == ReservedField::kFieldType ||
56           field->GetFieldType() == ChecksumStartField::kFieldType ||
57           field->GetFieldType() == FixedScalarField::kFieldType ||
58           field->GetFieldType() == CountField::kFieldType ||
59           field->GetFieldType() == SizeField::kFieldType) {
60         continue;
61       }
62 
63       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
64 
65       field->GenStringRepresentation(s, field->GetName() + "_");
66 
67       if (firstfield) {
68         firstfield = false;
69       }
70     }
71     s << ";";
72   }
73 
74   s << "ss << \" }\";";
75   s << "return ss.str();";
76   s << "}\n";
77 }
78 
// Emit the static Parse() member for the generated struct class.
//
// The generated function fills `to_fill` from `struct_begin_it` and, on success, returns
// `struct_begin_it + to_fill->size()`. Structs with a body return std::optional<Iterator>
// and report a short buffer as `{}`. Fixed-size structs return a plain Iterator and report
// a short buffer with an empty Subrange (NOTE(review): presumably callers treat an empty
// iterator as "parse failed" — confirm).
void StructDef::GenParse(std::ostream& s) const {
  std::string iterator =
          (is_little_endian_ ? "Iterator<kLittleEndian>" : "Iterator<!kLittleEndian>");

  // Return type: optional only when the struct has a body.
  if (fields_.HasBody()) {
    s << "static std::optional<" << iterator << ">";
  } else {
    s << "static " << iterator;
  }

  s << " Parse(" << name_ << "* to_fill, " << iterator << " struct_begin_it ";

  // Child structs take an extra `fill_parent` flag (default true) that controls whether the
  // parent's fields are parsed first.
  if (parent_ != nullptr) {
    s << ", bool fill_parent = true) {";
  } else {
    s << ") {";
  }
  s << "auto to_bound = struct_begin_it;";

  if (parent_ != nullptr) {
    s << "if (fill_parent) {";
    // The parent's own Parse() only has a fill_parent parameter if it also has a parent.
    std::string parent_param = (parent_->parent_ == nullptr ? "" : ", true");
    if (parent_->fields_.HasBody()) {
      // A parent with a body returns an optional: propagate failure if this struct can,
      // otherwise assert that the parent parse succeeded.
      s << "auto parent_optional_it = " << parent_->name_ << "::Parse(to_fill, to_bound"
        << parent_param << ");";
      if (fields_.HasBody()) {
        s << "if (!parent_optional_it) { return {}; }";
      } else {
        s << "ASSERT(parent_optional_it.has_value());";
      }
    } else {
      s << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
    }
    s << "}";
  }

  // Fixed-size structs: bail out when fewer than GetSize().bytes() bytes remain.
  if (!fields_.HasBody()) {
    s << "size_t end_index = struct_begin_it.NumBytesRemaining();";
    s << "if (end_index < " << GetSize().bytes() << ")";
    s << "{ return struct_begin_it.Subrange(0,0);}";
  }

  // Minimum-length check over the fields counted here. Reserved, body, fixed-scalar,
  // checksum-start, checksum and count fields are excluded; size fields are included.
  // NOTE(review): reserved/fixed fields presumably still occupy bytes — confirm why they
  // are left out of this check.
  Size total_bits{0};
  for (const auto& field : fields_) {
    if (field->GetFieldType() != ReservedField::kFieldType &&
        field->GetFieldType() != BodyField::kFieldType &&
        field->GetFieldType() != FixedScalarField::kFieldType &&
        field->GetFieldType() != ChecksumStartField::kFieldType &&
        field->GetFieldType() != ChecksumField::kFieldType &&
        field->GetFieldType() != CountField::kFieldType) {
      total_bits += field->GetSize().bits();
    }
  }
  s << "{";
  s << "if (to_bound.NumBytesRemaining() < " << total_bits.bytes() << ")";
  if (!fields_.HasBody()) {
    s << "{ return to_bound.Subrange(to_bound.NumBytesRemaining(),0);}";
  } else {
    s << "{ return {};}";
  }
  s << "}";
  for (const auto& field : fields_) {
    // Value fields: bound the field at its offset within the struct (see
    // GetStructOffsetForField) and extract it into `to_fill-><name>_`.
    if (field->GetFieldType() != ReservedField::kFieldType &&
        field->GetFieldType() != BodyField::kFieldType &&
        field->GetFieldType() != FixedScalarField::kFieldType &&
        field->GetFieldType() != SizeField::kFieldType &&
        field->GetFieldType() != ChecksumStartField::kFieldType &&
        field->GetFieldType() != ChecksumField::kFieldType &&
        field->GetFieldType() != CountField::kFieldType) {
      s << "{";
      int num_leading_bits = field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(),
                                              field->GetStructSize());
      s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_;";
      field->GenExtractor(s, num_leading_bits, true);
      s << "}";
    }
    // Count and size fields are extracted into the `<name>_extracted_` members that
    // GenDefinition declares for them, rather than into a `<name>_` member.
    if (field->GetFieldType() == CountField::kFieldType ||
        field->GetFieldType() == SizeField::kFieldType) {
      s << "{";
      int num_leading_bits = field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(),
                                              field->GetStructSize());
      s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_extracted_;";
      field->GenExtractor(s, num_leading_bits, true);
      s << "}";
    }
  }
  s << "return struct_begin_it + to_fill->size();";
  s << "}";
}
168 
GenParseFunctionPrototype(std::ostream & s) const169 void StructDef::GenParseFunctionPrototype(std::ostream& s) const {
170   s << "std::unique_ptr<" << name_ << "> Parse" << name_ << "(";
171   if (is_little_endian_) {
172     s << "Iterator<kLittleEndian>";
173   } else {
174     s << "Iterator<!kLittleEndian>";
175   }
176   s << "it);";
177 }
178 
GenDefinition(std::ostream & s) const179 void StructDef::GenDefinition(std::ostream& s) const {
180   s << "class " << name_;
181   if (parent_ != nullptr) {
182     s << " : public " << parent_->name_;
183   } else {
184     if (is_little_endian_) {
185       s << " : public PacketStruct<kLittleEndian>";
186     } else {
187       s << " : public PacketStruct<!kLittleEndian>";
188     }
189   }
190   s << " {";
191   s << " public:";
192 
193   GenDefaultConstructor(s);
194   GenConstructor(s);
195 
196   s << " public:\n";
197   s << "  virtual ~" << name_ << "() = default;\n";
198 
199   GenSerialize(s);
200   s << "\n";
201 
202   GenParse(s);
203   s << "\n";
204 
205   GenSize(s);
206   s << "\n";
207 
208   GenInstanceOf(s);
209   s << "\n";
210 
211   GenSpecialize(s);
212   s << "\n";
213 
214   GenToString(s);
215   s << "\n";
216 
217   GenMembers(s);
218   for (const auto& field : fields_) {
219     if (field->GetFieldType() == CountField::kFieldType ||
220         field->GetFieldType() == SizeField::kFieldType) {
221       s << "\n private:\n";
222       s << " mutable " << field->GetDataType() << " " << field->GetName() << "_extracted_{0};";
223     }
224   }
225   s << "};\n";
226 
227   if (fields_.HasBody()) {
228     GenParseFunctionPrototype(s);
229   }
230   s << "\n";
231 }
232 
GenDefinitionPybind11(std::ostream & s) const233 void StructDef::GenDefinitionPybind11(std::ostream& s) const {
234   s << "py::class_<" << name_;
235   if (parent_ != nullptr) {
236     s << ", " << parent_->name_;
237   } else {
238     if (is_little_endian_) {
239       s << ", PacketStruct<kLittleEndian>";
240     } else {
241       s << ", PacketStruct<!kLittleEndian>";
242     }
243   }
244   s << ", std::shared_ptr<" << name_ << ">";
245   s << ">(m, \"" << name_ << "\")";
246   s << ".def(py::init<>())";
247   s << ".def(\"Serialize\", [](" << GetTypeName() << "& obj){";
248   s << "std::vector<uint8_t> bytes;";
249   s << "BitInserter bi(bytes);";
250   s << "obj.Serialize(bi);";
251   s << "return bytes;})";
252   s << ".def(\"Parse\", &" << name_ << "::Parse)";
253   s << ".def(\"size\", &" << name_ << "::size)";
254   for (const auto& field : fields_) {
255     if (field->GetBuilderParameterType().empty()) {
256       continue;
257     }
258     s << ".def_readwrite(\"" << field->GetName() << "\", &" << name_ << "::" << field->GetName()
259       << "_)";
260   }
261   s << ";\n";
262 }
263 
264 // Generate constructor which provides default values for all struct fields.
GenDefaultConstructor(std::ostream & s) const265 void StructDef::GenDefaultConstructor(std::ostream& s) const {
266   if (parent_ != nullptr) {
267     s << name_ << "(const " << parent_->name_ << "& parent) : " << parent_->name_ << "(parent) {}";
268     s << name_ << "() : " << parent_->name_ << "() {";
269   } else {
270     s << name_ << "() {";
271   }
272 
273   // Get the list of parent params.
274   FieldList parent_params;
275   if (parent_ != nullptr) {
276     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
277             PayloadField::kFieldType,
278             BodyField::kFieldType,
279     });
280 
281     // Set constrained parent fields to their correct values.
282     for (const auto& field : parent_params) {
283       const auto& constraint = parent_constraints_.find(field->GetName());
284       if (constraint != parent_constraints_.end()) {
285         s << parent_->name_ << "::" << field->GetName() << "_ = ";
286         if (field->GetFieldType() == ScalarField::kFieldType) {
287           s << std::get<int64_t>(constraint->second) << ";";
288         } else if (field->GetFieldType() == EnumField::kFieldType) {
289           s << std::get<std::string>(constraint->second) << ";";
290         } else {
291           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
292         }
293       }
294     }
295   }
296 
297   s << "}\n";
298 }
299 
300 // Generate constructor which inputs initial field values for all struct fields.
GenConstructor(std::ostream & s) const301 void StructDef::GenConstructor(std::ostream& s) const {
302   // Fetch the list of parameters and parent paremeters.
303   // GetParamList returns the list of all inherited fields that do not
304   // have a constrained value.
305   FieldList parent_params;
306   FieldList params = GetParamList().GetFieldsWithoutTypes({
307           PayloadField::kFieldType,
308           BodyField::kFieldType,
309   });
310 
311   if (parent_ != nullptr) {
312     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
313             PayloadField::kFieldType,
314             BodyField::kFieldType,
315     });
316   }
317 
318   // Generate constructor parameters for struct fields.
319   s << name_ << "(";
320   bool add_comma = false;
321   for (auto const& field : params) {
322     if (add_comma) {
323       s << ", ";
324     }
325     field->GenBuilderParameter(s);
326     add_comma = true;
327   }
328 
329   s << ")" << std::endl;
330 
331   if (params.size() > 0) {
332     s << " : ";
333   }
334 
335   // Invoke parent constructor with correct field values.
336   if (parent_ != nullptr) {
337     s << parent_->name_ << "(";
338     add_comma = false;
339     for (auto const& field : parent_params) {
340       if (add_comma) {
341         s << ", ";
342       }
343 
344       // Check for fields with constraint value.
345       const auto& constraint = parent_constraints_.find(field->GetName());
346       if (constraint != parent_constraints_.end()) {
347         s << "/* " << field->GetName() << " */ ";
348         if (field->GetFieldType() == ScalarField::kFieldType) {
349           s << std::get<int64_t>(constraint->second);
350         } else if (field->GetFieldType() == EnumField::kFieldType) {
351           s << std::get<std::string>(constraint->second);
352         } else {
353           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
354         }
355       } else if (field->BuilderParameterMustBeMoved()) {
356         s << "std::move(" << field->GetName() << ")";
357       } else {
358         s << field->GetName();
359       }
360 
361       add_comma = true;
362     }
363 
364     s << ")";
365   }
366 
367   // Initialize remaining fields.
368   add_comma = parent_ != nullptr;
369   for (auto const& field : params) {
370     if (fields_.GetField(field->GetName()) == nullptr) {
371       continue;
372     }
373     if (add_comma) {
374       s << ", ";
375     }
376 
377     if (field->BuilderParameterMustBeMoved()) {
378       s << field->GetName() << "_(std::move(" << field->GetName() << "))";
379     } else {
380       s << field->GetName() << "_(" << field->GetName() << ")";
381     }
382 
383     add_comma = true;
384   }
385 
386   s << std::endl;
387   s << "{}\n";
388 }
389 
GetStructOffsetForField(std::string field_name) const390 Size StructDef::GetStructOffsetForField(std::string field_name) const {
391   auto size = Size(0);
392   for (auto it = fields_.begin(); it != fields_.end(); it++) {
393     // We've reached the field, end the loop.
394     if ((*it)->GetName() == field_name) {
395       break;
396     }
397     const auto& field = *it;
398     // When we need to parse this field, all previous fields should already be parsed.
399     if (field->GetStructSize().empty()) {
400       ERROR() << "Empty size for field " << (*it)->GetName()
401               << " finding the offset for field: " << field_name;
402     }
403     size += field->GetStructSize();
404   }
405 
406   // We need the offset until a body field.
407   if (parent_ != nullptr) {
408     auto parent_body_offset = static_cast<StructDef*>(parent_)->GetStructOffsetForField("body");
409     if (parent_body_offset.empty()) {
410       ERROR() << "Empty offset for body in " << parent_->name_
411               << " finding the offset for field: " << field_name;
412     }
413     size += parent_body_offset;
414   }
415 
416   return size;
417 }
418