xref: /aosp_15_r20/external/perfetto/src/protozero/protoc_plugin/cppgen_plugin.cc (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26 
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/io/printer.h>
31 
32 #include "perfetto/ext/base/string_utils.h"
33 
34 namespace protozero {
35 namespace {
36 
37 using namespace google::protobuf;
38 using namespace google::protobuf::compiler;
39 using namespace google::protobuf::io;
40 using perfetto::base::SplitString;
41 using perfetto::base::StripChars;
42 using perfetto::base::StripSuffix;
43 using perfetto::base::ToUpper;
44 
45 static constexpr auto TYPE_STRING = FieldDescriptor::TYPE_STRING;
46 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
47 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
48 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
49 
50 static const char kHeader[] =
51     "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
52 
53 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
54  public:
55   CppObjGenerator();
56   ~CppObjGenerator() override;
57 
58   // CodeGenerator implementation
59   bool Generate(const google::protobuf::FileDescriptor* file,
60                 const std::string& options,
61                 GeneratorContext* context,
62                 std::string* error) const override;
63 
64  private:
65   std::string GetCppType(const FieldDescriptor* field, bool constref) const;
66   std::string GetProtozeroSetter(const FieldDescriptor* field) const;
67   std::string GetPackedBuffer(const FieldDescriptor* field) const;
68   std::string GetPackedWireType(const FieldDescriptor* field) const;
69 
70   void GenEnum(const EnumDescriptor*, Printer*) const;
71   void GenEnumAliases(const EnumDescriptor*, Printer*) const;
72   void GenClassDecl(const Descriptor*, Printer*) const;
73   void GenClassDef(const Descriptor*, Printer*) const;
74 
GetNamespaces(const FileDescriptor * file) const75   std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
76     std::string pkg = file->package() + wrapper_namespace_;
77     return SplitString(pkg, ".");
78   }
79 
80   template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const81   std::string GetFullName(const T* msg, bool with_namespace = false) const {
82     std::string full_type;
83     full_type.append(msg->name());
84     for (const Descriptor* par = msg->containing_type(); par;
85          par = par->containing_type()) {
86       full_type.insert(0, par->name() + "_");
87     }
88     if (with_namespace) {
89       std::string prefix;
90       for (const std::string& ns : GetNamespaces(msg->file())) {
91         prefix += ns + "::";
92       }
93       full_type = prefix + full_type;
94     }
95     return full_type;
96   }
97 
98   template <class T>
HasSamePackage(const T * descriptor) const99   bool HasSamePackage(const T* descriptor) const {
100     return descriptor->file()->package() == package_;
101   }
102 
103   mutable std::string wrapper_namespace_;
104   mutable std::string package_;
105 };
106 
107 CppObjGenerator::CppObjGenerator() = default;
108 CppObjGenerator::~CppObjGenerator() = default;
109 
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const110 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
111                                const std::string& options,
112                                GeneratorContext* context,
113                                std::string* error) const {
114   for (const std::string& option : SplitString(options, ",")) {
115     std::vector<std::string> option_pair = SplitString(option, "=");
116     if (option_pair[0] == "wrapper_namespace") {
117       wrapper_namespace_ =
118           option_pair.size() == 2 ? "." + option_pair[1] : std::string();
119     } else {
120       *error = "Unknown plugin option: " + option_pair[0];
121       return false;
122     }
123   }
124 
125   package_ = file->package();
126 
127   auto get_file_name = [](const FileDescriptor* proto) {
128     return StripSuffix(proto->name(), ".proto") + ".gen";
129   };
130 
131   const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
132       context->Open(get_file_name(file) + ".h"));
133   const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
134       context->Open(get_file_name(file) + ".cc"));
135 
136   // Variables are delimited by $.
137   Printer h_printer(h_fstream.get(), '$');
138   Printer cc_printer(cc_fstream.get(), '$');
139 
140   std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
141   include_guard = ToUpper(include_guard);
142   include_guard = StripChars(include_guard, ".-/\\", '_');
143 
144   h_printer.Print(kHeader);
145   h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
146   h_printer.Print("#include <stdint.h>\n");
147   h_printer.Print("#include <bitset>\n");
148   h_printer.Print("#include <vector>\n");
149   h_printer.Print("#include <string>\n");
150   h_printer.Print("#include <type_traits>\n\n");
151   h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
152   h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
153   h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
154 
155   cc_printer.Print("#include \"perfetto/protozero/gen_field_helpers.h\"\n");
156   cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
157   cc_printer.Print(
158       "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
159   cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
160   cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
161   cc_printer.Print(kHeader);
162   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
163   cc_printer.Print("#pragma GCC diagnostic push\n");
164   cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
165   cc_printer.Print("#endif\n");
166 
167   // Generate includes for translated types of dependencies.
168 
169   // Figure out the subset of imports that are used only for lazy fields. We
170   // won't emit a C++ #include for them. This code is overly aggressive at
171   // removing imports: it rules them out as soon as it sees one lazy field
172   // whose type is defined in that import. A 100% correct solution would require
173   // to check that *all* dependent types for a given import are lazy before
174   // excluding that. In practice we don't need that because we don't use imports
175   // for both lazy and non-lazy fields.
176   std::set<std::string> lazy_imports;
177   for (int m = 0; m < file->message_type_count(); m++) {
178     const Descriptor* msg = file->message_type(m);
179     for (int i = 0; i < msg->field_count(); i++) {
180       const FieldDescriptor* field = msg->field(i);
181       if (field->options().lazy()) {
182         lazy_imports.insert(field->message_type()->file()->name());
183       }
184     }
185   }
186 
187   // Recursively traverse all imports and turn them into #include(s).
188   std::vector<const FileDescriptor*> imports_to_visit;
189   std::set<const FileDescriptor*> imports_visited;
190   imports_to_visit.push_back(file);
191 
192   while (!imports_to_visit.empty()) {
193     const FileDescriptor* cur = imports_to_visit.back();
194     imports_to_visit.pop_back();
195     imports_visited.insert(cur);
196     std::string base_name = StripSuffix(cur->name(), ".proto");
197     cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
198     for (int i = 0; i < cur->dependency_count(); i++) {
199       const FileDescriptor* dep = cur->dependency(i);
200       if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
201         continue;
202       imports_to_visit.push_back(dep);
203     }
204   }
205 
206   // Compute all nested types to generate forward declarations later.
207 
208   std::set<const Descriptor*> all_types_seen;  // All deps
209   std::set<const EnumDescriptor*> all_enums_seen;
210 
211   // We track the types additionally in vectors to guarantee a stable order in
212   // the generated output.
213   std::vector<const Descriptor*> local_types;  // Cur .proto file only.
214   std::vector<const Descriptor*> all_types;    // All deps
215   std::vector<const EnumDescriptor*> local_enums;
216   std::vector<const EnumDescriptor*> all_enums;
217 
218   auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
219                    &file](const EnumDescriptor* enum_desc) {
220     if (all_enums_seen.count(enum_desc))
221       return;
222     all_enums_seen.insert(enum_desc);
223     all_enums.push_back(enum_desc);
224     if (enum_desc->file() == file)
225       local_enums.push_back(enum_desc);
226   };
227 
228   for (int i = 0; i < file->enum_type_count(); i++)
229     add_enum(file->enum_type(i));
230 
231   std::stack<const Descriptor*> recursion_stack;
232   for (int i = 0; i < file->message_type_count(); i++)
233     recursion_stack.push(file->message_type(i));
234 
235   while (!recursion_stack.empty()) {
236     const Descriptor* msg = recursion_stack.top();
237     recursion_stack.pop();
238     if (all_types_seen.count(msg))
239       continue;
240     all_types_seen.insert(msg);
241     all_types.push_back(msg);
242     if (msg->file() == file)
243       local_types.push_back(msg);
244 
245     for (int i = 0; i < msg->nested_type_count(); i++)
246       recursion_stack.push(msg->nested_type(i));
247 
248     for (int i = 0; i < msg->enum_type_count(); i++)
249       add_enum(msg->enum_type(i));
250 
251     for (int i = 0; i < msg->field_count(); i++) {
252       const FieldDescriptor* field = msg->field(i);
253       if (field->has_default_value()) {
254         *error = "field " + field->name() +
255                  ": Explicitly declared default values are not supported";
256         return false;
257       }
258       if (field->options().lazy() &&
259           (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
260         *error = "[lazy=true] is supported only on non-repeated fields\n";
261         return false;
262       }
263 
264       if (field->type() == TYPE_MESSAGE && !field->options().lazy())
265         recursion_stack.push(field->message_type());
266 
267       if (field->type() == FieldDescriptor::TYPE_ENUM)
268         add_enum(field->enum_type());
269     }
270   }  //  while (!recursion_stack.empty())
271 
272   // Generate forward declarations in the header for proto types.
273   // Note: do NOT add #includes to other generated headers (either .gen.h or
274   // .pbzero.h). Doing so is extremely hard to handle at the build-system level
275   // and requires propagating public_deps everywhere.
276   cc_printer.Print("\n");
277 
278   // -- Begin of fwd declarations.
279 
280   // Build up the map of forward declarations.
281   std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
282   enum FwdType { kClass, kEnum };
283   auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
284                                    const std::string& full_name) {
285     auto dot = full_name.rfind("::");
286     PERFETTO_CHECK(dot != std::string::npos);
287     auto package = full_name.substr(0, dot);
288     auto name = full_name.substr(dot + 2);
289     if (cpp_type == kClass) {
290       fwd_decls.emplace(package, "class " + name + ";");
291     } else {
292       PERFETTO_CHECK(cpp_type == kEnum);
293       fwd_decls.emplace(package, "enum " + name + " : int;");
294     }
295   };
296 
297   add_fwd_decl(kClass, "protozero::Message");
298   for (const Descriptor* msg : all_types) {
299     add_fwd_decl(kClass, GetFullName(msg, true));
300   }
301   for (const EnumDescriptor* enm : all_enums) {
302     add_fwd_decl(kEnum, GetFullName(enm, true));
303   }
304 
305   // Emit forward declarations grouping by package.
306   std::string last_package;
307   auto close_last_package = [&last_package, &h_printer] {
308     if (!last_package.empty()) {
309       for (const std::string& ns : SplitString(last_package, "::"))
310         h_printer.Print("}  // namespace $ns$\n", "ns", ns);
311       h_printer.Print("\n");
312     }
313   };
314   for (const auto& kv : fwd_decls) {
315     const std::string& package = kv.first;
316     if (package != last_package) {
317       close_last_package();
318       last_package = package;
319       for (const std::string& ns : SplitString(package, "::"))
320         h_printer.Print("namespace $ns$ {\n", "ns", ns);
321     }
322     h_printer.Print("$decl$\n", "decl", kv.second);
323   }
324   close_last_package();
325 
326   // -- End of fwd declarations.
327 
328   for (const std::string& ns : GetNamespaces(file)) {
329     h_printer.Print("namespace $n$ {\n", "n", ns);
330     cc_printer.Print("namespace $n$ {\n", "n", ns);
331   }
332 
333   // Generate declarations and definitions.
334   for (const EnumDescriptor* enm : local_enums)
335     GenEnum(enm, &h_printer);
336 
337   for (const Descriptor* msg : local_types) {
338     GenClassDecl(msg, &h_printer);
339     GenClassDef(msg, &cc_printer);
340   }
341 
342   for (const std::string& ns : GetNamespaces(file)) {
343     h_printer.Print("}  // namespace $n$\n", "n", ns);
344     cc_printer.Print("}  // namespace $n$\n", "n", ns);
345   }
346   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
347   cc_printer.Print("#pragma GCC diagnostic pop\n");
348   cc_printer.Print("#endif\n");
349 
350   h_printer.Print("\n#endif  // $g$\n", "g", include_guard);
351 
352   return true;
353 }
354 
GetCppType(const FieldDescriptor * field,bool constref) const355 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
356                                         bool constref) const {
357   switch (field->type()) {
358     case FieldDescriptor::TYPE_DOUBLE:
359       return "double";
360     case FieldDescriptor::TYPE_FLOAT:
361       return "float";
362     case FieldDescriptor::TYPE_FIXED32:
363     case FieldDescriptor::TYPE_UINT32:
364       return "uint32_t";
365     case FieldDescriptor::TYPE_SFIXED32:
366     case FieldDescriptor::TYPE_INT32:
367     case FieldDescriptor::TYPE_SINT32:
368       return "int32_t";
369     case FieldDescriptor::TYPE_FIXED64:
370     case FieldDescriptor::TYPE_UINT64:
371       return "uint64_t";
372     case FieldDescriptor::TYPE_SFIXED64:
373     case FieldDescriptor::TYPE_SINT64:
374     case FieldDescriptor::TYPE_INT64:
375       return "int64_t";
376     case FieldDescriptor::TYPE_BOOL:
377       return "bool";
378     case FieldDescriptor::TYPE_STRING:
379     case FieldDescriptor::TYPE_BYTES:
380       return constref ? "const std::string&" : "std::string";
381     case FieldDescriptor::TYPE_MESSAGE:
382       assert(!field->options().lazy());
383       return constref
384                  ? "const " +
385                        GetFullName(field->message_type(),
386                                    !HasSamePackage(field->message_type())) +
387                        "&"
388                  : GetFullName(field->message_type(),
389                                !HasSamePackage(field->message_type()));
390     case FieldDescriptor::TYPE_ENUM:
391       return GetFullName(field->enum_type(),
392                          !HasSamePackage(field->enum_type()));
393     case FieldDescriptor::TYPE_GROUP:
394       abort();
395   }
396   abort();  // for gcc
397 }
398 
GetProtozeroSetter(const FieldDescriptor * field) const399 std::string CppObjGenerator::GetProtozeroSetter(
400     const FieldDescriptor* field) const {
401   switch (field->type()) {
402     case FieldDescriptor::TYPE_BOOL:
403       return "::protozero::internal::gen_helpers::SerializeTinyVarInt";
404     case FieldDescriptor::TYPE_INT32:
405     case FieldDescriptor::TYPE_INT64:
406     case FieldDescriptor::TYPE_UINT32:
407     case FieldDescriptor::TYPE_UINT64:
408     case FieldDescriptor::TYPE_ENUM:
409       return "::protozero::internal::gen_helpers::SerializeVarInt";
410     case FieldDescriptor::TYPE_SINT32:
411     case FieldDescriptor::TYPE_SINT64:
412       return "::protozero::internal::gen_helpers::SerializeSignedVarInt";
413     case FieldDescriptor::TYPE_FIXED32:
414     case FieldDescriptor::TYPE_FIXED64:
415     case FieldDescriptor::TYPE_SFIXED32:
416     case FieldDescriptor::TYPE_SFIXED64:
417     case FieldDescriptor::TYPE_FLOAT:
418     case FieldDescriptor::TYPE_DOUBLE:
419       return "::protozero::internal::gen_helpers::SerializeFixed";
420     case FieldDescriptor::TYPE_STRING:
421     case FieldDescriptor::TYPE_BYTES:
422       return "::protozero::internal::gen_helpers::SerializeString";
423     case FieldDescriptor::TYPE_GROUP:
424     case FieldDescriptor::TYPE_MESSAGE:
425       abort();
426   }
427   abort();
428 }
429 
GetPackedBuffer(const FieldDescriptor * field) const430 std::string CppObjGenerator::GetPackedBuffer(
431     const FieldDescriptor* field) const {
432   switch (field->type()) {
433     case FieldDescriptor::TYPE_FIXED32:
434       return "::protozero::PackedFixedSizeInt<uint32_t>";
435     case FieldDescriptor::TYPE_SFIXED32:
436       return "::protozero::PackedFixedSizeInt<int32_t>";
437     case FieldDescriptor::TYPE_FIXED64:
438       return "::protozero::PackedFixedSizeInt<uint64_t>";
439     case FieldDescriptor::TYPE_SFIXED64:
440       return "::protozero::PackedFixedSizeInt<int64_t>";
441     case FieldDescriptor::TYPE_DOUBLE:
442       return "::protozero::PackedFixedSizeInt<double>";
443     case FieldDescriptor::TYPE_FLOAT:
444       return "::protozero::PackedFixedSizeInt<float>";
445     case FieldDescriptor::TYPE_INT32:
446     case FieldDescriptor::TYPE_SINT32:
447     case FieldDescriptor::TYPE_UINT32:
448     case FieldDescriptor::TYPE_INT64:
449     case FieldDescriptor::TYPE_UINT64:
450     case FieldDescriptor::TYPE_SINT64:
451     case FieldDescriptor::TYPE_BOOL:
452     case FieldDescriptor::TYPE_ENUM:
453       return "::protozero::PackedVarInt";
454     case FieldDescriptor::TYPE_STRING:
455     case FieldDescriptor::TYPE_BYTES:
456     case FieldDescriptor::TYPE_MESSAGE:
457     case FieldDescriptor::TYPE_GROUP:
458       break;  // Will abort()
459   }
460   abort();
461 }
462 
GetPackedWireType(const FieldDescriptor * field) const463 std::string CppObjGenerator::GetPackedWireType(
464     const FieldDescriptor* field) const {
465   switch (field->type()) {
466     case FieldDescriptor::TYPE_FIXED32:
467     case FieldDescriptor::TYPE_SFIXED32:
468     case FieldDescriptor::TYPE_FLOAT:
469       return "::protozero::proto_utils::ProtoWireType::kFixed32";
470     case FieldDescriptor::TYPE_FIXED64:
471     case FieldDescriptor::TYPE_SFIXED64:
472     case FieldDescriptor::TYPE_DOUBLE:
473       return "::protozero::proto_utils::ProtoWireType::kFixed64";
474     case FieldDescriptor::TYPE_INT32:
475     case FieldDescriptor::TYPE_SINT32:
476     case FieldDescriptor::TYPE_UINT32:
477     case FieldDescriptor::TYPE_INT64:
478     case FieldDescriptor::TYPE_UINT64:
479     case FieldDescriptor::TYPE_SINT64:
480     case FieldDescriptor::TYPE_BOOL:
481     case FieldDescriptor::TYPE_ENUM:
482       return "::protozero::proto_utils::ProtoWireType::kVarInt";
483     case FieldDescriptor::TYPE_STRING:
484     case FieldDescriptor::TYPE_BYTES:
485     case FieldDescriptor::TYPE_MESSAGE:
486     case FieldDescriptor::TYPE_GROUP:
487       break;  // Will abort()
488   }
489   abort();
490 }
491 
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const492 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
493                               Printer* p) const {
494   std::string full_name = GetFullName(enum_desc);
495 
496   // When generating enums, there are two cases:
497   // 1. Enums nested in a message (most frequent case), e.g.:
498   //    message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
499   // 2. Enum defined at the package level, outside of any message.
500   //
501   // In the case 1, the C++ code generated by the official protobuf library is:
502   // enum MyEnum {  MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
503   // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
504   //
505   // In the case 2, the C++ code is simply:
506   // enum MyEnum { FOO=1, BAR=2 }
507   // Hence this |prefix| logic.
508   std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
509   p->Print("enum $f$ : int {\n", "f", full_name);
510   for (int e = 0; e < enum_desc->value_count(); e++) {
511     const EnumValueDescriptor* value = enum_desc->value(e);
512     p->Print("  $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
513              std::to_string(value->number()));
514   }
515   p->Print("};\n");
516 }
517 
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const518 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
519                                      Printer* p) const {
520   int min_value = std::numeric_limits<int>::max();
521   int max_value = std::numeric_limits<int>::min();
522   std::string min_name;
523   std::string max_name;
524   std::string full_name = GetFullName(enum_desc);
525   for (int e = 0; e < enum_desc->value_count(); e++) {
526     const EnumValueDescriptor* value = enum_desc->value(e);
527     p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
528              value->name());
529     if (value->number() < min_value) {
530       min_value = value->number();
531       min_name = full_name + "_" + value->name();
532     }
533     if (value->number() > max_value) {
534       max_value = value->number();
535       max_name = full_name + "_" + value->name();
536     }
537   }
538   p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
539            "m", min_name);
540   p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
541            "m", max_name);
542 }
543 
GenClassDecl(const Descriptor * msg,Printer * p) const544 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
545   std::string full_name = GetFullName(msg);
546   p->Print(
547       "\nclass PERFETTO_EXPORT_COMPONENT $n$ : public "
548       "::protozero::CppMessageObj {\n",
549       "n", full_name);
550   p->Print(" public:\n");
551   p->Indent();
552 
553   // Do a first pass to generate aliases for nested types.
554   // e.g., using Foo = Parent_Foo;
555   for (int i = 0; i < msg->nested_type_count(); i++) {
556     const Descriptor* nested_msg = msg->nested_type(i);
557     p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
558              GetFullName(nested_msg));
559   }
560   for (int i = 0; i < msg->enum_type_count(); i++) {
561     const EnumDescriptor* nested_enum = msg->enum_type(i);
562     p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
563              GetFullName(nested_enum));
564     GenEnumAliases(nested_enum, p);
565   }
566 
567   // Generate constants with field numbers.
568   p->Print("enum FieldNumbers {\n");
569   for (int i = 0; i < msg->field_count(); i++) {
570     const FieldDescriptor* field = msg->field(i);
571     std::string name = field->camelcase_name();
572     name[0] = perfetto::base::Uppercase(name[0]);
573     p->Print("  k$n$FieldNumber = $num$,\n", "n", name, "num",
574              std::to_string(field->number()));
575   }
576   p->Print("};\n\n");
577 
578   p->Print("$n$();\n", "n", full_name);
579   p->Print("~$n$() override;\n", "n", full_name);
580   p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
581   p->Print("$n$& operator=($n$&&);\n", "n", full_name);
582   p->Print("$n$(const $n$&);\n", "n", full_name);
583   p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
584   p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
585   p->Print(
586       "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
587       "n", full_name);
588   p->Print("\n");
589 
590   std::string proto_type = GetFullName(msg, true);
591   p->Print("bool ParseFromArray(const void*, size_t) override;\n");
592   p->Print("std::string SerializeAsString() const override;\n");
593   p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
594   p->Print("void Serialize(::protozero::Message*) const;\n");
595 
596   // Generate accessors.
597   for (int i = 0; i < msg->field_count(); i++) {
598     const FieldDescriptor* field = msg->field(i);
599     auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
600     p->Print("\n");
601     if (field->options().lazy()) {
602       p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
603                field->lowercase_name());
604       p->Print(
605           "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
606           "n", field->lowercase_name(), "s", set_bit);
607     } else if (!field->is_repeated()) {
608       p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
609                field->lowercase_name(), "bit", std::to_string(field->number()));
610       if (field->type() == TYPE_MESSAGE) {
611         p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
612                  GetCppType(field, true), "n", field->lowercase_name());
613         p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
614                  GetCppType(field, false), "n", field->lowercase_name(), "s",
615                  set_bit);
616       } else {
617         p->Print("$t$ $n$() const { return $n$_; }\n", "t",
618                  GetCppType(field, true), "n", field->lowercase_name());
619         p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
620                  GetCppType(field, true), "n", field->lowercase_name(), "s",
621                  set_bit);
622         if (field->type() == FieldDescriptor::TYPE_BYTES) {
623           p->Print(
624               "void set_$n$(const void* p, size_t s) { "
625               "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
626               "n", field->lowercase_name(), "s", set_bit);
627         }
628       }
629     } else {  // is_repeated()
630       p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
631                GetCppType(field, false), "n", field->lowercase_name());
632       p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
633                GetCppType(field, false), "n", field->lowercase_name());
634 
635       // Generate accessors for repeated message types in the .cc file so that
636       // the header doesn't depend on the full definition of all nested types.
637       if (field->type() == TYPE_MESSAGE) {
638         p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
639                  field->lowercase_name());
640         p->Print("void clear_$n$();\n", "n", field->lowercase_name());
641         p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
642                  field->lowercase_name());
643       } else {  // Primitive type.
644         p->Print(
645             "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
646             "t", GetCppType(field, false), "n", field->lowercase_name());
647         p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
648                  field->lowercase_name());
649         p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
650                  GetCppType(field, false), "n", field->lowercase_name());
651         // TODO(primiano): this should be done only for TYPE_MESSAGE.
652         // Unfortuntely we didn't realize before and now we have a bunch of code
653         // that does: *msg->add_int_value() = 42 instead of
654         // msg->add_int_value(42).
655         p->Print(
656             "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
657             "t", GetCppType(field, false), "n", field->lowercase_name());
658       }
659     }
660   }
661   p->Outdent();
662   p->Print("\n private:\n");
663   p->Indent();
664 
665   // Generate fields.
666   int max_field_id = 1;
667   for (int i = 0; i < msg->field_count(); i++) {
668     const FieldDescriptor* field = msg->field(i);
669     max_field_id = std::max(max_field_id, field->number());
670     if (field->options().lazy()) {
671       p->Print("std::string $n$_;  // [lazy=true]\n", "n",
672                field->lowercase_name());
673     } else if (!field->is_repeated()) {
674       std::string type = GetCppType(field, false);
675       if (field->type() == TYPE_MESSAGE) {
676         type = "::protozero::CopyablePtr<" + type + ">";
677         p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
678       } else {
679         p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
680       }
681     } else {  // is_repeated()
682       p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
683                field->lowercase_name());
684     }
685   }
686   p->Print("\n");
687   p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
688   p->Print("// with future versions of .proto files.\n");
689   p->Print("std::string unknown_fields_;\n");
690 
691   p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
692            std::to_string(max_field_id + 1));
693 
694   p->Outdent();
695   p->Print("};\n\n");
696 }
697 
GenClassDef(const Descriptor * msg,Printer * p) const698 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
699   p->Print("\n");
700   std::string full_name = GetFullName(msg);
701 
702   p->Print("$n$::$n$() = default;\n", "n", full_name);
703   p->Print("$n$::~$n$() = default;\n", "n", full_name);
704   p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
705   p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
706   p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
707   p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
708 
709   p->Print("\n");
710 
711   // Comparison operator
712   p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
713   p->Indent();
714 
715   p->Print(
716       "return ::protozero::internal::gen_helpers::EqualsField(unknown_fields_, "
717       "other.unknown_fields_)");
718   for (int i = 0; i < msg->field_count(); i++)
719     p->Print(
720         "\n && ::protozero::internal::gen_helpers::EqualsField($n$_, "
721         "other.$n$_)",
722         "n", msg->field(i)->lowercase_name());
723   p->Print(";");
724   p->Outdent();
725   p->Print("\n}\n\n");
726 
727   // Accessors for repeated message fields.
728   for (int i = 0; i < msg->field_count(); i++) {
729     const FieldDescriptor* field = msg->field(i);
730     if (field->options().lazy() || !field->is_repeated() ||
731         field->type() != TYPE_MESSAGE) {
732       continue;
733     }
734     p->Print(
735         "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
736         "c", full_name, "t", GetCppType(field, false), "n",
737         field->lowercase_name());
738     p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
739              field->lowercase_name());
740     p->Print(
741         "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
742         "c", full_name, "t", GetCppType(field, false), "n",
743         field->lowercase_name());
744   }
745 
746   std::string proto_type = GetFullName(msg, true);
747 
748   // Generate the ParseFromArray() method definition.
749   p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
750            full_name);
751   p->Indent();
752   for (int i = 0; i < msg->field_count(); i++) {
753     const FieldDescriptor* field = msg->field(i);
754     if (field->is_repeated())
755       p->Print("$n$_.clear();\n", "n", field->lowercase_name());
756   }
757   p->Print("unknown_fields_.clear();\n");
758   p->Print("bool packed_error = false;\n");
759   p->Print("\n");
760   p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
761   p->Print("for (auto field = dec.ReadField(); field.valid(); ");
762   p->Print("field = dec.ReadField()) {\n");
763   p->Indent();
764   p->Print("if (field.id() < _has_field_.size()) {\n");
765   p->Print("  _has_field_.set(field.id());\n");
766   p->Print("}\n");
767   p->Print("switch (field.id()) {\n");
768   p->Indent();
769   for (int i = 0; i < msg->field_count(); i++) {
770     const FieldDescriptor* field = msg->field(i);
771     p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
772              "n", field->lowercase_name());
773     p->Indent();
774     if (field->options().lazy()) {
775       p->Print(
776           "::protozero::internal::gen_helpers::DeserializeString(field, "
777           "&$n$_);\n",
778           "n", field->lowercase_name());
779     } else {
780       std::string statement;
781       if (field->type() == TYPE_MESSAGE) {
782         statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
783       } else {
784         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
785           // sint32/64 fields are special and need to be zig-zag-decoded.
786           statement = "field.get_signed(&$rval$);\n";
787         } else if (field->type() == TYPE_STRING) {
788           statement =
789               "::protozero::internal::gen_helpers::DeserializeString(field, "
790               "&$rval$);\n";
791         } else {
792           statement = "field.get(&$rval$);\n";
793         }
794       }
795       if (field->is_packed()) {
796         PERFETTO_CHECK(field->is_repeated());
797         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
798           PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
799         }
800         p->Print(
801             "if "
802             "(!::protozero::internal::gen_helpers::DeserializePackedRepeated"
803             "<$w$, $c$>(field, &$n$_)) {\n",
804             "w", GetPackedWireType(field), "c", GetCppType(field, false), "n",
805             field->lowercase_name());
806         p->Print("  packed_error = true;");
807         p->Print("}\n");
808       } else if (field->is_repeated()) {
809         p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
810         p->Print(statement.c_str(), "rval",
811                  field->lowercase_name() + "_.back()");
812       } else if (field->type() == TYPE_MESSAGE) {
813         p->Print(statement.c_str(), "rval",
814                  "(*" + field->lowercase_name() + "_)");
815       } else {
816         p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
817       }
818     }
819     p->Print("break;\n");
820     p->Outdent();
821   }  // for (field)
822   p->Print("default:\n");
823   p->Print("  field.SerializeAndAppendTo(&unknown_fields_);\n");
824   p->Print("  break;\n");
825   p->Outdent();
826   p->Print("}\n");  // switch(field.id)
827   p->Outdent();
828   p->Print("}\n");                                           // for(field)
829   p->Print("return !packed_error && !dec.bytes_left();\n");  // for(field)
830   p->Outdent();
831   p->Print("}\n\n");
832 
833   // Generate the SerializeAsString() method definition.
834   p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
835   p->Indent();
836   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
837   p->Print("Serialize(msg.get());\n");
838   p->Print("return msg.SerializeAsString();\n");
839   p->Outdent();
840   p->Print("}\n\n");
841 
842   // Generate the SerializeAsArray() method definition.
843   p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
844            full_name);
845   p->Indent();
846   p->Print("::protozero::internal::gen_helpers::MessageSerializer msg;\n");
847   p->Print("Serialize(msg.get());\n");
848   p->Print("return msg.SerializeAsArray();\n");
849   p->Outdent();
850   p->Print("}\n\n");
851 
852   // Generate the Serialize() method that writes the fields into the passed
853   // protozero |msg| write-only interface |msg|.
854   p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
855            full_name);
856   p->Indent();
857   for (int i = 0; i < msg->field_count(); i++) {
858     const FieldDescriptor* field = msg->field(i);
859     std::map<std::string, std::string> args;
860     args["id"] = std::to_string(field->number());
861     args["n"] = field->lowercase_name();
862     p->Print(args, "// Field $id$: $n$\n");
863     if (field->is_packed()) {
864       PERFETTO_CHECK(field->is_repeated());
865       p->Print("{\n");
866       p->Indent();
867       p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
868       p->Print(args, "for (auto& it : $n$_)\n");
869       p->Print(args, "  pack.Append(it);\n");
870       p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
871       p->Outdent();
872       p->Print("}\n");
873     } else {
874       if (field->is_repeated()) {
875         p->Print(args, "for (auto& it : $n$_) {\n");
876         args["lvalue"] = "it";
877         args["rvalue"] = "it";
878       } else {
879         p->Print(args, "if (_has_field_[$id$]) {\n");
880         args["lvalue"] = "(*" + field->lowercase_name() + "_)";
881         args["rvalue"] = field->lowercase_name() + "_";
882       }
883       p->Indent();
884       if (field->options().lazy()) {
885         p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
886       } else if (field->type() == TYPE_MESSAGE) {
887         p->Print(args,
888                  "$lvalue$.Serialize("
889                  "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
890       } else {
891         args["setter"] = GetProtozeroSetter(field);
892         p->Print(args, "$setter$($id$, $rvalue$, msg);\n");
893       }
894       p->Outdent();
895       p->Print("}\n");
896     }
897 
898     p->Print("\n");
899   }  // for (field)
900   p->Print(
901       "protozero::internal::gen_helpers::SerializeUnknownFields(unknown_fields_"
902       ", msg);\n");
903   p->Outdent();
904   p->Print("}\n\n");
905 }
906 
907 }  // namespace
908 }  // namespace protozero
909 
main(int argc,char ** argv)910 int main(int argc, char** argv) {
911   ::protozero::CppObjGenerator generator;
912   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
913 }
914