1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <stdio.h>
18 #include <stdlib.h>
19
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/io/printer.h>
31
32 #include "perfetto/ext/base/string_utils.h"
33
34 namespace protozero {
35 namespace {
36
37 using namespace google::protobuf;
38 using namespace google::protobuf::compiler;
39 using namespace google::protobuf::io;
40 using perfetto::base::SplitString;
41 using perfetto::base::StripChars;
42 using perfetto::base::StripSuffix;
43 using perfetto::base::ToUpper;
44
45 static constexpr auto TYPE_STRING = FieldDescriptor::TYPE_STRING;
46 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
47 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
48 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
49
50 static const char kHeader[] =
51 "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
52
53 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
54 public:
55 CppObjGenerator();
56 ~CppObjGenerator() override;
57
58 // CodeGenerator implementation
59 bool Generate(const google::protobuf::FileDescriptor* file,
60 const std::string& options,
61 GeneratorContext* context,
62 std::string* error) const override;
63
64 private:
65 std::string GetCppType(const FieldDescriptor* field, bool constref) const;
66 std::string GetProtozeroSetter(const FieldDescriptor* field) const;
67 std::string GetPackedBuffer(const FieldDescriptor* field) const;
68 std::string GetPackedWireType(const FieldDescriptor* field) const;
69
70 void GenEnum(const EnumDescriptor*, Printer*) const;
71 void GenEnumAliases(const EnumDescriptor*, Printer*) const;
72 void GenClassDecl(const Descriptor*, Printer*) const;
73 void GenClassDef(const Descriptor*, Printer*) const;
74
GetNamespaces(const FileDescriptor * file) const75 std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
76 std::string pkg = file->package() + wrapper_namespace_;
77 return SplitString(pkg, ".");
78 }
79
80 template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const81 std::string GetFullName(const T* msg, bool with_namespace = false) const {
82 std::string full_type;
83 full_type.append(msg->name());
84 for (const Descriptor* par = msg->containing_type(); par;
85 par = par->containing_type()) {
86 full_type.insert(0, par->name() + "_");
87 }
88 if (with_namespace) {
89 std::string prefix;
90 for (const std::string& ns : GetNamespaces(msg->file())) {
91 prefix += ns + "::";
92 }
93 full_type = prefix + full_type;
94 }
95 return full_type;
96 }
97
98 template <class T>
HasSamePackage(const T * descriptor) const99 bool HasSamePackage(const T* descriptor) const {
100 return descriptor->file()->package() == package_;
101 }
102
103 mutable std::string wrapper_namespace_;
104 mutable std::string package_;
105 };
106
107 CppObjGenerator::CppObjGenerator() = default;
108 CppObjGenerator::~CppObjGenerator() = default;
109
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const110 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
111 const std::string& options,
112 GeneratorContext* context,
113 std::string* error) const {
114 for (const std::string& option : SplitString(options, ",")) {
115 std::vector<std::string> option_pair = SplitString(option, "=");
116 if (option_pair[0] == "wrapper_namespace") {
117 wrapper_namespace_ =
118 option_pair.size() == 2 ? "." + option_pair[1] : std::string();
119 } else {
120 *error = "Unknown plugin option: " + option_pair[0];
121 return false;
122 }
123 }
124
125 package_ = file->package();
126
127 auto get_file_name = [](const FileDescriptor* proto) {
128 return StripSuffix(proto->name(), ".proto") + ".gen";
129 };
130
131 const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
132 context->Open(get_file_name(file) + ".h"));
133 const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
134 context->Open(get_file_name(file) + ".cc"));
135
136 // Variables are delimited by $.
137 Printer h_printer(h_fstream.get(), '$');
138 Printer cc_printer(cc_fstream.get(), '$');
139
140 std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
141 include_guard = ToUpper(include_guard);
142 include_guard = StripChars(include_guard, ".-/\\", '_');
143
144 h_printer.Print(kHeader);
145 h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
146 h_printer.Print("#include <stdint.h>\n");
147 h_printer.Print("#include <bitset>\n");
148 h_printer.Print("#include <vector>\n");
149 h_printer.Print("#include <string>\n");
150 h_printer.Print("#include <type_traits>\n\n");
151 h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
152 h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
153 h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
154
155 cc_printer.Print("#include \"perfetto/protozero/gen_field_helpers.h\"\n");
156 cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
157 cc_printer.Print(
158 "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
159 cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
160 cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
161 cc_printer.Print(kHeader);
162 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
163 cc_printer.Print("#pragma GCC diagnostic push\n");
164 cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
165 cc_printer.Print("#endif\n");
166
167 // Generate includes for translated types of dependencies.
168
169 // Figure out the subset of imports that are used only for lazy fields. We
170 // won't emit a C++ #include for them. This code is overly aggressive at
171 // removing imports: it rules them out as soon as it sees one lazy field
172 // whose type is defined in that import. A 100% correct solution would require
173 // to check that *all* dependent types for a given import are lazy before
174 // excluding that. In practice we don't need that because we don't use imports
175 // for both lazy and non-lazy fields.
176 std::set<std::string> lazy_imports;
177 for (int m = 0; m < file->message_type_count(); m++) {
178 const Descriptor* msg = file->message_type(m);
179 for (int i = 0; i < msg->field_count(); i++) {
180 const FieldDescriptor* field = msg->field(i);
181 if (field->options().lazy()) {
182 lazy_imports.insert(field->message_type()->file()->name());
183 }
184 }
185 }
186
187 // Recursively traverse all imports and turn them into #include(s).
188 std::vector<const FileDescriptor*> imports_to_visit;
189 std::set<const FileDescriptor*> imports_visited;
190 imports_to_visit.push_back(file);
191
192 while (!imports_to_visit.empty()) {
193 const FileDescriptor* cur = imports_to_visit.back();
194 imports_to_visit.pop_back();
195 imports_visited.insert(cur);
196 std::string base_name = StripSuffix(cur->name(), ".proto");
197 cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
198 for (int i = 0; i < cur->dependency_count(); i++) {
199 const FileDescriptor* dep = cur->dependency(i);
200 if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
201 continue;
202 imports_to_visit.push_back(dep);
203 }
204 }
205
206 // Compute all nested types to generate forward declarations later.
207
208 std::set<const Descriptor*> all_types_seen; // All deps
209 std::set<const EnumDescriptor*> all_enums_seen;
210
211 // We track the types additionally in vectors to guarantee a stable order in
212 // the generated output.
213 std::vector<const Descriptor*> local_types; // Cur .proto file only.
214 std::vector<const Descriptor*> all_types; // All deps
215 std::vector<const EnumDescriptor*> local_enums;
216 std::vector<const EnumDescriptor*> all_enums;
217
218 auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
219 &file](const EnumDescriptor* enum_desc) {
220 if (all_enums_seen.count(enum_desc))
221 return;
222 all_enums_seen.insert(enum_desc);
223 all_enums.push_back(enum_desc);
224 if (enum_desc->file() == file)
225 local_enums.push_back(enum_desc);
226 };
227
228 for (int i = 0; i < file->enum_type_count(); i++)
229 add_enum(file->enum_type(i));
230
231 std::stack<const Descriptor*> recursion_stack;
232 for (int i = 0; i < file->message_type_count(); i++)
233 recursion_stack.push(file->message_type(i));
234
235 while (!recursion_stack.empty()) {
236 const Descriptor* msg = recursion_stack.top();
237 recursion_stack.pop();
238 if (all_types_seen.count(msg))
239 continue;
240 all_types_seen.insert(msg);
241 all_types.push_back(msg);
242 if (msg->file() == file)
243 local_types.push_back(msg);
244
245 for (int i = 0; i < msg->nested_type_count(); i++)
246 recursion_stack.push(msg->nested_type(i));
247
248 for (int i = 0; i < msg->enum_type_count(); i++)
249 add_enum(msg->enum_type(i));
250
251 for (int i = 0; i < msg->field_count(); i++) {
252 const FieldDescriptor* field = msg->field(i);
253 if (field->has_default_value()) {
254 *error = "field " + field->name() +
255 ": Explicitly declared default values are not supported";
256 return false;
257 }
258 if (field->options().lazy() &&
259 (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
260 *error = "[lazy=true] is supported only on non-repeated fields\n";
261 return false;
262 }
263
264 if (field->type() == TYPE_MESSAGE && !field->options().lazy())
265 recursion_stack.push(field->message_type());
266
267 if (field->type() == FieldDescriptor::TYPE_ENUM)
268 add_enum(field->enum_type());
269 }
270 } // while (!recursion_stack.empty())
271
272 // Generate forward declarations in the header for proto types.
273 // Note: do NOT add #includes to other generated headers (either .gen.h or
274 // .pbzero.h). Doing so is extremely hard to handle at the build-system level
275 // and requires propagating public_deps everywhere.
276 cc_printer.Print("\n");
277
278 // -- Begin of fwd declarations.
279
280 // Build up the map of forward declarations.
281 std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
282 enum FwdType { kClass, kEnum };
283 auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
284 const std::string& full_name) {
285 auto dot = full_name.rfind("::");
286 PERFETTO_CHECK(dot != std::string::npos);
287 auto package = full_name.substr(0, dot);
288 auto name = full_name.substr(dot + 2);
289 if (cpp_type == kClass) {
290 fwd_decls.emplace(package, "class " + name + ";");
291 } else {
292 PERFETTO_CHECK(cpp_type == kEnum);
293 fwd_decls.emplace(package, "enum " + name + " : int;");
294 }
295 };
296
297 add_fwd_decl(kClass, "protozero::Message");
298 for (const Descriptor* msg : all_types) {
299 add_fwd_decl(kClass, GetFullName(msg, true));
300 }
301 for (const EnumDescriptor* enm : all_enums) {
302 add_fwd_decl(kEnum, GetFullName(enm, true));
303 }
304
305 // Emit forward declarations grouping by package.
306 std::string last_package;
307 auto close_last_package = [&last_package, &h_printer] {
308 if (!last_package.empty()) {
309 for (const std::string& ns : SplitString(last_package, "::"))
310 h_printer.Print("} // namespace $ns$\n", "ns", ns);
311 h_printer.Print("\n");
312 }
313 };
314 for (const auto& kv : fwd_decls) {
315 const std::string& package = kv.first;
316 if (package != last_package) {
317 close_last_package();
318 last_package = package;
319 for (const std::string& ns : SplitString(package, "::"))
320 h_printer.Print("namespace $ns$ {\n", "ns", ns);
321 }
322 h_printer.Print("$decl$\n", "decl", kv.second);
323 }
324 close_last_package();
325
326 // -- End of fwd declarations.
327
328 for (const std::string& ns : GetNamespaces(file)) {
329 h_printer.Print("namespace $n$ {\n", "n", ns);
330 cc_printer.Print("namespace $n$ {\n", "n", ns);
331 }
332
333 // Generate declarations and definitions.
334 for (const EnumDescriptor* enm : local_enums)
335 GenEnum(enm, &h_printer);
336
337 for (const Descriptor* msg : local_types) {
338 GenClassDecl(msg, &h_printer);
339 GenClassDef(msg, &cc_printer);
340 }
341
342 for (const std::string& ns : GetNamespaces(file)) {
343 h_printer.Print("} // namespace $n$\n", "n", ns);
344 cc_printer.Print("} // namespace $n$\n", "n", ns);
345 }
346 cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
347 cc_printer.Print("#pragma GCC diagnostic pop\n");
348 cc_printer.Print("#endif\n");
349
350 h_printer.Print("\n#endif // $g$\n", "g", include_guard);
351
352 return true;
353 }
354
GetCppType(const FieldDescriptor * field,bool constref) const355 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
356 bool constref) const {
357 switch (field->type()) {
358 case FieldDescriptor::TYPE_DOUBLE:
359 return "double";
360 case FieldDescriptor::TYPE_FLOAT:
361 return "float";
362 case FieldDescriptor::TYPE_FIXED32:
363 case FieldDescriptor::TYPE_UINT32:
364 return "uint32_t";
365 case FieldDescriptor::TYPE_SFIXED32:
366 case FieldDescriptor::TYPE_INT32:
367 case FieldDescriptor::TYPE_SINT32:
368 return "int32_t";
369 case FieldDescriptor::TYPE_FIXED64:
370 case FieldDescriptor::TYPE_UINT64:
371 return "uint64_t";
372 case FieldDescriptor::TYPE_SFIXED64:
373 case FieldDescriptor::TYPE_SINT64:
374 case FieldDescriptor::TYPE_INT64:
375 return "int64_t";
376 case FieldDescriptor::TYPE_BOOL:
377 return "bool";
378 case FieldDescriptor::TYPE_STRING:
379 case FieldDescriptor::TYPE_BYTES:
380 return constref ? "const std::string&" : "std::string";
381 case FieldDescriptor::TYPE_MESSAGE:
382 assert(!field->options().lazy());
383 return constref
384 ? "const " +
385 GetFullName(field->message_type(),
386 !HasSamePackage(field->message_type())) +
387 "&"
388 : GetFullName(field->message_type(),
389 !HasSamePackage(field->message_type()));
390 case FieldDescriptor::TYPE_ENUM:
391 return GetFullName(field->enum_type(),
392 !HasSamePackage(field->enum_type()));
393 case FieldDescriptor::TYPE_GROUP:
394 abort();
395 }
396 abort(); // for gcc
397 }
398
GetProtozeroSetter(const FieldDescriptor * field) const399 std::string CppObjGenerator::GetProtozeroSetter(
400 const FieldDescriptor* field) const {
401 switch (field->type()) {
402 case FieldDescriptor::TYPE_BOOL:
403 return "::protozero::internal::gen_helpers::SerializeTinyVarInt";
404 case FieldDescriptor::TYPE_INT32:
405 case FieldDescriptor::TYPE_INT64:
406 case FieldDescriptor::TYPE_UINT32:
407 case FieldDescriptor::TYPE_UINT64:
408 case FieldDescriptor::TYPE_ENUM:
409 return "::protozero::internal::gen_helpers::SerializeVarInt";
410 case FieldDescriptor::TYPE_SINT32:
411 case FieldDescriptor::TYPE_SINT64:
412 return "::protozero::internal::gen_helpers::SerializeSignedVarInt";
413 case FieldDescriptor::TYPE_FIXED32:
414 case FieldDescriptor::TYPE_FIXED64:
415 case FieldDescriptor::TYPE_SFIXED32:
416 case FieldDescriptor::TYPE_SFIXED64:
417 case FieldDescriptor::TYPE_FLOAT:
418 case FieldDescriptor::TYPE_DOUBLE:
419 return "::protozero::internal::gen_helpers::SerializeFixed";
420 case FieldDescriptor::TYPE_STRING:
421 case FieldDescriptor::TYPE_BYTES:
422 return "::protozero::internal::gen_helpers::SerializeString";
423 case FieldDescriptor::TYPE_GROUP:
424 case FieldDescriptor::TYPE_MESSAGE:
425 abort();
426 }
427 abort();
428 }
429
GetPackedBuffer(const FieldDescriptor * field) const430 std::string CppObjGenerator::GetPackedBuffer(
431 const FieldDescriptor* field) const {
432 switch (field->type()) {
433 case FieldDescriptor::TYPE_FIXED32:
434 return "::protozero::PackedFixedSizeInt<uint32_t>";
435 case FieldDescriptor::TYPE_SFIXED32:
436 return "::protozero::PackedFixedSizeInt<int32_t>";
437 case FieldDescriptor::TYPE_FIXED64:
438 return "::protozero::PackedFixedSizeInt<uint64_t>";
439 case FieldDescriptor::TYPE_SFIXED64:
440 return "::protozero::PackedFixedSizeInt<int64_t>";
441 case FieldDescriptor::TYPE_DOUBLE:
442 return "::protozero::PackedFixedSizeInt<double>";
443 case FieldDescriptor::TYPE_FLOAT:
444 return "::protozero::PackedFixedSizeInt<float>";
445 case FieldDescriptor::TYPE_INT32:
446 case FieldDescriptor::TYPE_SINT32:
447 case FieldDescriptor::TYPE_UINT32:
448 case FieldDescriptor::TYPE_INT64:
449 case FieldDescriptor::TYPE_UINT64:
450 case FieldDescriptor::TYPE_SINT64:
451 case FieldDescriptor::TYPE_BOOL:
452 case FieldDescriptor::TYPE_ENUM:
453 return "::protozero::PackedVarInt";
454 case FieldDescriptor::TYPE_STRING:
455 case FieldDescriptor::TYPE_BYTES:
456 case FieldDescriptor::TYPE_MESSAGE:
457 case FieldDescriptor::TYPE_GROUP:
458 break; // Will abort()
459 }
460 abort();
461 }
462
GetPackedWireType(const FieldDescriptor * field) const463 std::string CppObjGenerator::GetPackedWireType(
464 const FieldDescriptor* field) const {
465 switch (field->type()) {
466 case FieldDescriptor::TYPE_FIXED32:
467 case FieldDescriptor::TYPE_SFIXED32:
468 case FieldDescriptor::TYPE_FLOAT:
469 return "::protozero::proto_utils::ProtoWireType::kFixed32";
470 case FieldDescriptor::TYPE_FIXED64:
471 case FieldDescriptor::TYPE_SFIXED64:
472 case FieldDescriptor::TYPE_DOUBLE:
473 return "::protozero::proto_utils::ProtoWireType::kFixed64";
474 case FieldDescriptor::TYPE_INT32:
475 case FieldDescriptor::TYPE_SINT32:
476 case FieldDescriptor::TYPE_UINT32:
477 case FieldDescriptor::TYPE_INT64:
478 case FieldDescriptor::TYPE_UINT64:
479 case FieldDescriptor::TYPE_SINT64:
480 case FieldDescriptor::TYPE_BOOL:
481 case FieldDescriptor::TYPE_ENUM:
482 return "::protozero::proto_utils::ProtoWireType::kVarInt";
483 case FieldDescriptor::TYPE_STRING:
484 case FieldDescriptor::TYPE_BYTES:
485 case FieldDescriptor::TYPE_MESSAGE:
486 case FieldDescriptor::TYPE_GROUP:
487 break; // Will abort()
488 }
489 abort();
490 }
491
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const492 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
493 Printer* p) const {
494 std::string full_name = GetFullName(enum_desc);
495
496 // When generating enums, there are two cases:
497 // 1. Enums nested in a message (most frequent case), e.g.:
498 // message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
499 // 2. Enum defined at the package level, outside of any message.
500 //
501 // In the case 1, the C++ code generated by the official protobuf library is:
502 // enum MyEnum { MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
503 // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
504 //
505 // In the case 2, the C++ code is simply:
506 // enum MyEnum { FOO=1, BAR=2 }
507 // Hence this |prefix| logic.
508 std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
509 p->Print("enum $f$ : int {\n", "f", full_name);
510 for (int e = 0; e < enum_desc->value_count(); e++) {
511 const EnumValueDescriptor* value = enum_desc->value(e);
512 p->Print(" $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
513 std::to_string(value->number()));
514 }
515 p->Print("};\n");
516 }
517
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const518 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
519 Printer* p) const {
520 int min_value = std::numeric_limits<int>::max();
521 int max_value = std::numeric_limits<int>::min();
522 std::string min_name;
523 std::string max_name;
524 std::string full_name = GetFullName(enum_desc);
525 for (int e = 0; e < enum_desc->value_count(); e++) {
526 const EnumValueDescriptor* value = enum_desc->value(e);
527 p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
528 value->name());
529 if (value->number() < min_value) {
530 min_value = value->number();
531 min_name = full_name + "_" + value->name();
532 }
533 if (value->number() > max_value) {
534 max_value = value->number();
535 max_name = full_name + "_" + value->name();
536 }
537 }
538 p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
539 "m", min_name);
540 p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
541 "m", max_name);
542 }
543
GenClassDecl(const Descriptor * msg,Printer * p) const544 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
545 std::string full_name = GetFullName(msg);
546 p->Print(
547 "\nclass PERFETTO_EXPORT_COMPONENT $n$ : public "
548 "::protozero::CppMessageObj {\n",
549 "n", full_name);
550 p->Print(" public:\n");
551 p->Indent();
552
553 // Do a first pass to generate aliases for nested types.
554 // e.g., using Foo = Parent_Foo;
555 for (int i = 0; i < msg->nested_type_count(); i++) {
556 const Descriptor* nested_msg = msg->nested_type(i);
557 p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
558 GetFullName(nested_msg));
559 }
560 for (int i = 0; i < msg->enum_type_count(); i++) {
561 const EnumDescriptor* nested_enum = msg->enum_type(i);
562 p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
563 GetFullName(nested_enum));
564 GenEnumAliases(nested_enum, p);
565 }
566
567 // Generate constants with field numbers.
568 p->Print("enum FieldNumbers {\n");
569 for (int i = 0; i < msg->field_count(); i++) {
570 const FieldDescriptor* field = msg->field(i);
571 std::string name = field->camelcase_name();
572 name[0] = perfetto::base::Uppercase(name[0]);
573 p->Print(" k$n$FieldNumber = $num$,\n", "n", name, "num",
574 std::to_string(field->number()));
575 }
576 p->Print("};\n\n");
577
578 p->Print("$n$();\n", "n", full_name);
579 p->Print("~$n$() override;\n", "n", full_name);
580 p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
581 p->Print("$n$& operator=($n$&&);\n", "n", full_name);
582 p->Print("$n$(const $n$&);\n", "n", full_name);
583 p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
584 p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
585 p->Print(
586 "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
587 "n", full_name);
588 p->Print("\n");
589
590 std::string proto_type = GetFullName(msg, true);
591 p->Print("bool ParseFromArray(const void*, size_t) override;\n");
592 p->Print("std::string SerializeAsString() const override;\n");
593 p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
594 p->Print("void Serialize(::protozero::Message*) const;\n");
595
596 // Generate accessors.
597 for (int i = 0; i < msg->field_count(); i++) {
598 const FieldDescriptor* field = msg->field(i);
599 auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
600 p->Print("\n");
601 if (field->options().lazy()) {
602 p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
603 field->lowercase_name());
604 p->Print(
605 "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
606 "n", field->lowercase_name(), "s", set_bit);
607 } else if (!field->is_repeated()) {
608 p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
609 field->lowercase_name(), "bit", std::to_string(field->number()));
610 if (field->type() == TYPE_MESSAGE) {
611 p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
612 GetCppType(field, true), "n", field->lowercase_name());
613 p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
614 GetCppType(field, false), "n", field->lowercase_name(), "s",
615 set_bit);
616 } else {
617 p->Print("$t$ $n$() const { return $n$_; }\n", "t",
618 GetCppType(field, true), "n", field->lowercase_name());
619 p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
620 GetCppType(field, true), "n", field->lowercase_name(), "s",
621 set_bit);
622 if (field->type() == FieldDescriptor::TYPE_BYTES) {
623 p->Print(
624 "void set_$n$(const void* p, size_t s) { "
625 "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
626 "n", field->lowercase_name(), "s", set_bit);
627 }
628 }
629 } else { // is_repeated()
630 p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
631 GetCppType(field, false), "n", field->lowercase_name());
632 p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
633 GetCppType(field, false), "n", field->lowercase_name());
634
635 // Generate accessors for repeated message types in the .cc file so that
636 // the header doesn't depend on the full definition of all nested types.
637 if (field->type() == TYPE_MESSAGE) {
638 p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
639 field->lowercase_name());
640 p->Print("void clear_$n$();\n", "n", field->lowercase_name());
641 p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
642 field->lowercase_name());
643 } else { // Primitive type.
644 p->Print(
645 "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
646 "t", GetCppType(field, false), "n", field->lowercase_name());
647 p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
648 field->lowercase_name());
649 p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
650 GetCppType(field, false), "n", field->lowercase_name());
651 // TODO(primiano): this should be done only for TYPE_MESSAGE.
652 // Unfortuntely we didn't realize before and now we have a bunch of code
653 // that does: *msg->add_int_value() = 42 instead of
654 // msg->add_int_value(42).
655 p->Print(
656 "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
657 "t", GetCppType(field, false), "n", field->lowercase_name());
658 }
659 }
660 }
661 p->Outdent();
662 p->Print("\n private:\n");
663 p->Indent();
664
665 // Generate fields.
666 int max_field_id = 1;
667 for (int i = 0; i < msg->field_count(); i++) {
668 const FieldDescriptor* field = msg->field(i);
669 max_field_id = std::max(max_field_id, field->number());
670 if (field->options().lazy()) {
671 p->Print("std::string $n$_; // [lazy=true]\n", "n",
672 field->lowercase_name());
673 } else if (!field->is_repeated()) {
674 std::string type = GetCppType(field, false);
675 if (field->type() == TYPE_MESSAGE) {
676 type = "::protozero::CopyablePtr<" + type + ">";
677 p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
678 } else {
679 p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
680 }
681 } else { // is_repeated()
682 p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
683 field->lowercase_name());
684 }
685 }
686 p->Print("\n");
687 p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
688 p->Print("// with future versions of .proto files.\n");
689 p->Print("std::string unknown_fields_;\n");
690
691 p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
692 std::to_string(max_field_id + 1));
693
694 p->Outdent();
695 p->Print("};\n\n");
696 }
697
GenClassDef(const Descriptor * msg,Printer * p) const698 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
699 p->Print("\n");
700 std::string full_name = GetFullName(msg);
701
702 p->Print("$n$::$n$() = default;\n", "n", full_name);
703 p->Print("$n$::~$n$() = default;\n", "n", full_name);
704 p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
705 p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
706 p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
707 p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
708
709 p->Print("\n");
710
711 // Comparison operator
712 p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
713 p->Indent();
714
715 p->Print(
716 "return ::protozero::internal::gen_helpers::EqualsField(unknown_fields_, "
717 "other.unknown_fields_)");
718 for (int i = 0; i < msg->field_count(); i++)
719 p->Print(
720 "\n && ::protozero::internal::gen_helpers::EqualsField($n$_, "
721 "other.$n$_)",
722 "n", msg->field(i)->lowercase_name());
723 p->Print(";");
724 p->Outdent();
725 p->Print("\n}\n\n");
726
727 // Accessors for repeated message fields.
728 for (int i = 0; i < msg->field_count(); i++) {
729 const FieldDescriptor* field = msg->field(i);
730 if (field->options().lazy() || !field->is_repeated() ||
731 field->type() != TYPE_MESSAGE) {
732 continue;
733 }
734 p->Print(
735 "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
736 "c", full_name, "t", GetCppType(field, false), "n",
737 field->lowercase_name());
738 p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
739 field->lowercase_name());
740 p->Print(
741 "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
742 "c", full_name, "t", GetCppType(field, false), "n",
743 field->lowercase_name());
744 }
745
746 std::string proto_type = GetFullName(msg, true);
747
748 // Generate the ParseFromArray() method definition.
749 p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
750 full_name);
751 p->Indent();
752 for (int i = 0; i < msg->field_count(); i++) {
753 const FieldDescriptor* field = msg->field(i);
754 if (field->is_repeated())
755 p->Print("$n$_.clear();\n", "n", field->lowercase_name());
756 }
757 p->Print("unknown_fields_.clear();\n");
758 p->Print("bool packed_error = false;\n");
759 p->Print("\n");
760 p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
761 p->Print("for (auto field = dec.ReadField(); field.valid(); ");
762 p->Print("field = dec.ReadField()) {\n");
763 p->Indent();
764 p->Print("if (field.id() < _has_field_.size()) {\n");
765 p->Print(" _has_field_.set(field.id());\n");
766 p->Print("}\n");
767 p->Print("switch (field.id()) {\n");
768 p->Indent();
769 for (int i = 0; i < msg->field_count(); i++) {
770 const FieldDescriptor* field = msg->field(i);
771 p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
772 "n", field->lowercase_name());
773 p->Indent();
774 if (field->options().lazy()) {
775 p->Print(
776 "::protozero::internal::gen_helpers::DeserializeString(field, "
777 "&$n$_);\n",
778 "n", field->lowercase_name());
779 } else {
780 std::string statement;
781 if (field->type() == TYPE_MESSAGE) {
782 statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
783 } else {
784 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
785 // sint32/64 fields are special and need to be zig-zag-decoded.
786 statement = "field.get_signed(&$rval$);\n";
787 } else if (field->type() == TYPE_STRING) {
788 statement =
789 "::protozero::internal::gen_helpers::DeserializeString(field, "
790 "&$rval$);\n";
791 } else {
792 statement = "field.get(&$rval$);\n";
793 }
794 }
795 if (field->is_packed()) {
796 PERFETTO_CHECK(field->is_repeated());
797 if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
798 PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
799 }
800 p->Print(
801 "if "
802 "(!::protozero::internal::gen_helpers::DeserializePackedRepeated"
803 "<$w$, $c$>(field, &$n$_)) {\n",
804 "w", GetPackedWireType(field), "c", GetCppType(field, false), "n",
805 field->lowercase_name());
806 p->Print(" packed_error = true;");
807 p->Print("}\n");
808 } else if (field->is_repeated()) {
809 p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
810 p->Print(statement.c_str(), "rval",
811 field->lowercase_name() + "_.back()");
812 } else if (field->type() == TYPE_MESSAGE) {
813 p->Print(statement.c_str(), "rval",
814 "(*" + field->lowercase_name() + "_)");
815 } else {
816 p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
817 }
818 }
819 p->Print("break;\n");
820 p->Outdent();
821 } // for (field)
822 p->Print("default:\n");
823 p->Print(" field.SerializeAndAppendTo(&unknown_fields_);\n");
824 p->Print(" break;\n");
825 p->Outdent();
826 p->Print("}\n"); // switch(field.id)
827 p->Outdent();
828 p->Print("}\n"); // for(field)
829 p->Print("return !packed_error && !dec.bytes_left();\n"); // for(field)
830 p->Outdent();
831 p->Print("}\n\n");
832
833 // Generate the SerializeAsString() method definition.
834 p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
835 p->Indent();
836 p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
837 p->Print("Serialize(msg.get());\n");
838 p->Print("return msg.SerializeAsString();\n");
839 p->Outdent();
840 p->Print("}\n\n");
841
842 // Generate the SerializeAsArray() method definition.
843 p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
844 full_name);
845 p->Indent();
846 p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
847 p->Print("Serialize(msg.get());\n");
848 p->Print("return msg.SerializeAsArray();\n");
849 p->Outdent();
850 p->Print("}\n\n");
851
852 // Generate the Serialize() method that writes the fields into the passed
853 // protozero |msg| write-only interface |msg|.
854 p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
855 full_name);
856 p->Indent();
857 for (int i = 0; i < msg->field_count(); i++) {
858 const FieldDescriptor* field = msg->field(i);
859 std::map<std::string, std::string> args;
860 args["id"] = std::to_string(field->number());
861 args["n"] = field->lowercase_name();
862 p->Print(args, "// Field $id$: $n$\n");
863 if (field->is_packed()) {
864 PERFETTO_CHECK(field->is_repeated());
865 p->Print("{\n");
866 p->Indent();
867 p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
868 p->Print(args, "for (auto& it : $n$_)\n");
869 p->Print(args, " pack.Append(it);\n");
870 p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
871 p->Outdent();
872 p->Print("}\n");
873 } else {
874 if (field->is_repeated()) {
875 p->Print(args, "for (auto& it : $n$_) {\n");
876 args["lvalue"] = "it";
877 args["rvalue"] = "it";
878 } else {
879 p->Print(args, "if (_has_field_[$id$]) {\n");
880 args["lvalue"] = "(*" + field->lowercase_name() + "_)";
881 args["rvalue"] = field->lowercase_name() + "_";
882 }
883 p->Indent();
884 if (field->options().lazy()) {
885 p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
886 } else if (field->type() == TYPE_MESSAGE) {
887 p->Print(args,
888 "$lvalue$.Serialize("
889 "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
890 } else {
891 args["setter"] = GetProtozeroSetter(field);
892 p->Print(args, "$setter$($id$, $rvalue$, msg);\n");
893 }
894 p->Outdent();
895 p->Print("}\n");
896 }
897
898 p->Print("\n");
899 } // for (field)
900 p->Print(
901 "protozero::internal::gen_helpers::SerializeUnknownFields(unknown_fields_"
902 ", msg);\n");
903 p->Outdent();
904 p->Print("}\n\n");
905 }
906
907 } // namespace
908 } // namespace protozero
909
main(int argc,char ** argv)910 int main(int argc, char** argv) {
911 ::protozero::CppObjGenerator generator;
912 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
913 }
914