1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "struct_def.h"

#include <string>
#include <utility>

#include "fields/all_fields.h"
#include "util.h"
21
StructDef(std::string name,FieldList fields)22 StructDef::StructDef(std::string name, FieldList fields) : StructDef(name, fields, nullptr) {}
StructDef(std::string name,FieldList fields,StructDef * parent)23 StructDef::StructDef(std::string name, FieldList fields, StructDef* parent)
24 : ParentDef(name, fields, parent), total_size_(GetSize(true)) {}
25
GetNewField(const std::string & name,ParseLocation loc) const26 PacketField* StructDef::GetNewField(const std::string& name, ParseLocation loc) const {
27 if (fields_.HasBody()) {
28 return new VariableLengthStructField(name, name_, loc);
29 } else {
30 return new StructField(name, name_, total_size_, loc);
31 }
32 }
33
GetDefinitionType() const34 TypeDef::Type StructDef::GetDefinitionType() const { return TypeDef::Type::STRUCT; }
35
GenSpecialize(std::ostream & s) const36 void StructDef::GenSpecialize(std::ostream& s) const {
37 if (parent_ == nullptr) {
38 return;
39 }
40 s << "static " << name_ << "* Specialize(" << parent_->name_ << "* parent) {";
41 s << "ASSERT(" << name_ << "::IsInstance(*parent));";
42 s << "return static_cast<" << name_ << "*>(parent);";
43 s << "}";
44 }
45
GenToString(std::ostream & s) const46 void StructDef::GenToString(std::ostream& s) const {
47 s << "std::string ToString() const {";
48 s << "std::stringstream ss;";
49 s << "ss << std::hex << std::showbase << \"" << name_ << " { \";";
50
51 if (fields_.size() > 0) {
52 s << "ss";
53 bool firstfield = true;
54 for (const auto& field : fields_) {
55 if (field->GetFieldType() == ReservedField::kFieldType ||
56 field->GetFieldType() == ChecksumStartField::kFieldType ||
57 field->GetFieldType() == FixedScalarField::kFieldType ||
58 field->GetFieldType() == CountField::kFieldType ||
59 field->GetFieldType() == SizeField::kFieldType) {
60 continue;
61 }
62
63 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
64
65 field->GenStringRepresentation(s, field->GetName() + "_");
66
67 if (firstfield) {
68 firstfield = false;
69 }
70 }
71 s << ";";
72 }
73
74 s << "ss << \" }\";";
75 s << "return ss.str();";
76 s << "}\n";
77 }
78
GenParse(std::ostream & s) const79 void StructDef::GenParse(std::ostream& s) const {
80 std::string iterator =
81 (is_little_endian_ ? "Iterator<kLittleEndian>" : "Iterator<!kLittleEndian>");
82
83 if (fields_.HasBody()) {
84 s << "static std::optional<" << iterator << ">";
85 } else {
86 s << "static " << iterator;
87 }
88
89 s << " Parse(" << name_ << "* to_fill, " << iterator << " struct_begin_it ";
90
91 if (parent_ != nullptr) {
92 s << ", bool fill_parent = true) {";
93 } else {
94 s << ") {";
95 }
96 s << "auto to_bound = struct_begin_it;";
97
98 if (parent_ != nullptr) {
99 s << "if (fill_parent) {";
100 std::string parent_param = (parent_->parent_ == nullptr ? "" : ", true");
101 if (parent_->fields_.HasBody()) {
102 s << "auto parent_optional_it = " << parent_->name_ << "::Parse(to_fill, to_bound"
103 << parent_param << ");";
104 if (fields_.HasBody()) {
105 s << "if (!parent_optional_it) { return {}; }";
106 } else {
107 s << "ASSERT(parent_optional_it.has_value());";
108 }
109 } else {
110 s << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
111 }
112 s << "}";
113 }
114
115 if (!fields_.HasBody()) {
116 s << "size_t end_index = struct_begin_it.NumBytesRemaining();";
117 s << "if (end_index < " << GetSize().bytes() << ")";
118 s << "{ return struct_begin_it.Subrange(0,0);}";
119 }
120
121 Size total_bits{0};
122 for (const auto& field : fields_) {
123 if (field->GetFieldType() != ReservedField::kFieldType &&
124 field->GetFieldType() != BodyField::kFieldType &&
125 field->GetFieldType() != FixedScalarField::kFieldType &&
126 field->GetFieldType() != ChecksumStartField::kFieldType &&
127 field->GetFieldType() != ChecksumField::kFieldType &&
128 field->GetFieldType() != CountField::kFieldType) {
129 total_bits += field->GetSize().bits();
130 }
131 }
132 s << "{";
133 s << "if (to_bound.NumBytesRemaining() < " << total_bits.bytes() << ")";
134 if (!fields_.HasBody()) {
135 s << "{ return to_bound.Subrange(to_bound.NumBytesRemaining(),0);}";
136 } else {
137 s << "{ return {};}";
138 }
139 s << "}";
140 for (const auto& field : fields_) {
141 if (field->GetFieldType() != ReservedField::kFieldType &&
142 field->GetFieldType() != BodyField::kFieldType &&
143 field->GetFieldType() != FixedScalarField::kFieldType &&
144 field->GetFieldType() != SizeField::kFieldType &&
145 field->GetFieldType() != ChecksumStartField::kFieldType &&
146 field->GetFieldType() != ChecksumField::kFieldType &&
147 field->GetFieldType() != CountField::kFieldType) {
148 s << "{";
149 int num_leading_bits = field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(),
150 field->GetStructSize());
151 s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_;";
152 field->GenExtractor(s, num_leading_bits, true);
153 s << "}";
154 }
155 if (field->GetFieldType() == CountField::kFieldType ||
156 field->GetFieldType() == SizeField::kFieldType) {
157 s << "{";
158 int num_leading_bits = field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(),
159 field->GetStructSize());
160 s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_extracted_;";
161 field->GenExtractor(s, num_leading_bits, true);
162 s << "}";
163 }
164 }
165 s << "return struct_begin_it + to_fill->size();";
166 s << "}";
167 }
168
GenParseFunctionPrototype(std::ostream & s) const169 void StructDef::GenParseFunctionPrototype(std::ostream& s) const {
170 s << "std::unique_ptr<" << name_ << "> Parse" << name_ << "(";
171 if (is_little_endian_) {
172 s << "Iterator<kLittleEndian>";
173 } else {
174 s << "Iterator<!kLittleEndian>";
175 }
176 s << "it);";
177 }
178
GenDefinition(std::ostream & s) const179 void StructDef::GenDefinition(std::ostream& s) const {
180 s << "class " << name_;
181 if (parent_ != nullptr) {
182 s << " : public " << parent_->name_;
183 } else {
184 if (is_little_endian_) {
185 s << " : public PacketStruct<kLittleEndian>";
186 } else {
187 s << " : public PacketStruct<!kLittleEndian>";
188 }
189 }
190 s << " {";
191 s << " public:";
192
193 GenDefaultConstructor(s);
194 GenConstructor(s);
195
196 s << " public:\n";
197 s << " virtual ~" << name_ << "() = default;\n";
198
199 GenSerialize(s);
200 s << "\n";
201
202 GenParse(s);
203 s << "\n";
204
205 GenSize(s);
206 s << "\n";
207
208 GenInstanceOf(s);
209 s << "\n";
210
211 GenSpecialize(s);
212 s << "\n";
213
214 GenToString(s);
215 s << "\n";
216
217 GenMembers(s);
218 for (const auto& field : fields_) {
219 if (field->GetFieldType() == CountField::kFieldType ||
220 field->GetFieldType() == SizeField::kFieldType) {
221 s << "\n private:\n";
222 s << " mutable " << field->GetDataType() << " " << field->GetName() << "_extracted_{0};";
223 }
224 }
225 s << "};\n";
226
227 if (fields_.HasBody()) {
228 GenParseFunctionPrototype(s);
229 }
230 s << "\n";
231 }
232
GenDefinitionPybind11(std::ostream & s) const233 void StructDef::GenDefinitionPybind11(std::ostream& s) const {
234 s << "py::class_<" << name_;
235 if (parent_ != nullptr) {
236 s << ", " << parent_->name_;
237 } else {
238 if (is_little_endian_) {
239 s << ", PacketStruct<kLittleEndian>";
240 } else {
241 s << ", PacketStruct<!kLittleEndian>";
242 }
243 }
244 s << ", std::shared_ptr<" << name_ << ">";
245 s << ">(m, \"" << name_ << "\")";
246 s << ".def(py::init<>())";
247 s << ".def(\"Serialize\", [](" << GetTypeName() << "& obj){";
248 s << "std::vector<uint8_t> bytes;";
249 s << "BitInserter bi(bytes);";
250 s << "obj.Serialize(bi);";
251 s << "return bytes;})";
252 s << ".def(\"Parse\", &" << name_ << "::Parse)";
253 s << ".def(\"size\", &" << name_ << "::size)";
254 for (const auto& field : fields_) {
255 if (field->GetBuilderParameterType().empty()) {
256 continue;
257 }
258 s << ".def_readwrite(\"" << field->GetName() << "\", &" << name_ << "::" << field->GetName()
259 << "_)";
260 }
261 s << ";\n";
262 }
263
264 // Generate constructor which provides default values for all struct fields.
GenDefaultConstructor(std::ostream & s) const265 void StructDef::GenDefaultConstructor(std::ostream& s) const {
266 if (parent_ != nullptr) {
267 s << name_ << "(const " << parent_->name_ << "& parent) : " << parent_->name_ << "(parent) {}";
268 s << name_ << "() : " << parent_->name_ << "() {";
269 } else {
270 s << name_ << "() {";
271 }
272
273 // Get the list of parent params.
274 FieldList parent_params;
275 if (parent_ != nullptr) {
276 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
277 PayloadField::kFieldType,
278 BodyField::kFieldType,
279 });
280
281 // Set constrained parent fields to their correct values.
282 for (const auto& field : parent_params) {
283 const auto& constraint = parent_constraints_.find(field->GetName());
284 if (constraint != parent_constraints_.end()) {
285 s << parent_->name_ << "::" << field->GetName() << "_ = ";
286 if (field->GetFieldType() == ScalarField::kFieldType) {
287 s << std::get<int64_t>(constraint->second) << ";";
288 } else if (field->GetFieldType() == EnumField::kFieldType) {
289 s << std::get<std::string>(constraint->second) << ";";
290 } else {
291 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
292 }
293 }
294 }
295 }
296
297 s << "}\n";
298 }
299
300 // Generate constructor which inputs initial field values for all struct fields.
GenConstructor(std::ostream & s) const301 void StructDef::GenConstructor(std::ostream& s) const {
302 // Fetch the list of parameters and parent paremeters.
303 // GetParamList returns the list of all inherited fields that do not
304 // have a constrained value.
305 FieldList parent_params;
306 FieldList params = GetParamList().GetFieldsWithoutTypes({
307 PayloadField::kFieldType,
308 BodyField::kFieldType,
309 });
310
311 if (parent_ != nullptr) {
312 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
313 PayloadField::kFieldType,
314 BodyField::kFieldType,
315 });
316 }
317
318 // Generate constructor parameters for struct fields.
319 s << name_ << "(";
320 bool add_comma = false;
321 for (auto const& field : params) {
322 if (add_comma) {
323 s << ", ";
324 }
325 field->GenBuilderParameter(s);
326 add_comma = true;
327 }
328
329 s << ")" << std::endl;
330
331 if (params.size() > 0) {
332 s << " : ";
333 }
334
335 // Invoke parent constructor with correct field values.
336 if (parent_ != nullptr) {
337 s << parent_->name_ << "(";
338 add_comma = false;
339 for (auto const& field : parent_params) {
340 if (add_comma) {
341 s << ", ";
342 }
343
344 // Check for fields with constraint value.
345 const auto& constraint = parent_constraints_.find(field->GetName());
346 if (constraint != parent_constraints_.end()) {
347 s << "/* " << field->GetName() << " */ ";
348 if (field->GetFieldType() == ScalarField::kFieldType) {
349 s << std::get<int64_t>(constraint->second);
350 } else if (field->GetFieldType() == EnumField::kFieldType) {
351 s << std::get<std::string>(constraint->second);
352 } else {
353 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
354 }
355 } else if (field->BuilderParameterMustBeMoved()) {
356 s << "std::move(" << field->GetName() << ")";
357 } else {
358 s << field->GetName();
359 }
360
361 add_comma = true;
362 }
363
364 s << ")";
365 }
366
367 // Initialize remaining fields.
368 add_comma = parent_ != nullptr;
369 for (auto const& field : params) {
370 if (fields_.GetField(field->GetName()) == nullptr) {
371 continue;
372 }
373 if (add_comma) {
374 s << ", ";
375 }
376
377 if (field->BuilderParameterMustBeMoved()) {
378 s << field->GetName() << "_(std::move(" << field->GetName() << "))";
379 } else {
380 s << field->GetName() << "_(" << field->GetName() << ")";
381 }
382
383 add_comma = true;
384 }
385
386 s << std::endl;
387 s << "{}\n";
388 }
389
GetStructOffsetForField(std::string field_name) const390 Size StructDef::GetStructOffsetForField(std::string field_name) const {
391 auto size = Size(0);
392 for (auto it = fields_.begin(); it != fields_.end(); it++) {
393 // We've reached the field, end the loop.
394 if ((*it)->GetName() == field_name) {
395 break;
396 }
397 const auto& field = *it;
398 // When we need to parse this field, all previous fields should already be parsed.
399 if (field->GetStructSize().empty()) {
400 ERROR() << "Empty size for field " << (*it)->GetName()
401 << " finding the offset for field: " << field_name;
402 }
403 size += field->GetStructSize();
404 }
405
406 // We need the offset until a body field.
407 if (parent_ != nullptr) {
408 auto parent_body_offset = static_cast<StructDef*>(parent_)->GetStructOffsetForField("body");
409 if (parent_body_offset.empty()) {
410 ERROR() << "Empty offset for body in " << parent_->name_
411 << " finding the offset for field: " << field_name;
412 }
413 size += parent_body_offset;
414 }
415
416 return size;
417 }
418