xref: /aosp_15_r20/external/perfetto/src/protozero/protoc_plugin/protozero_plugin.cc (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdlib.h>
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 
25 #include <google/protobuf/compiler/code_generator.h>
26 #include <google/protobuf/compiler/plugin.h>
27 #include <google/protobuf/descriptor.h>
28 #include <google/protobuf/descriptor.pb.h>
29 #include <google/protobuf/io/printer.h>
30 #include <google/protobuf/io/zero_copy_stream.h>
31 
32 #include "perfetto/ext/base/string_utils.h"
33 
34 namespace protozero {
35 namespace {
36 
37 using google::protobuf::Descriptor;
38 using google::protobuf::EnumDescriptor;
39 using google::protobuf::EnumValueDescriptor;
40 using google::protobuf::FieldDescriptor;
41 using google::protobuf::FileDescriptor;
42 using google::protobuf::compiler::GeneratorContext;
43 using google::protobuf::io::Printer;
44 using google::protobuf::io::ZeroCopyOutputStream;
45 using perfetto::base::ReplaceAll;
46 using perfetto::base::SplitString;
47 using perfetto::base::StripChars;
48 using perfetto::base::StripPrefix;
49 using perfetto::base::StripSuffix;
50 using perfetto::base::ToUpper;
51 using perfetto::base::Uppercase;
52 
53 // Keep this value in sync with ProtoDecoder::kMaxDecoderFieldId. If they go out
54 // of sync pbzero.h files will stop compiling, hitting the at() static_assert.
55 // Not worth an extra dependency.
56 constexpr int kMaxDecoderFieldId = 999;
57 
// Terminates the process through abort() when |condition| is false; does
// nothing otherwise.
void Assert(bool condition) {
  if (condition)
    return;
  abort();
}
62 
63 struct FileDescriptorComp {
operator ()protozero::__anonf52e20e30111::FileDescriptorComp64   bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
65     int comp = lhs->name().compare(rhs->name());
66     Assert(comp != 0 || lhs == rhs);
67     return comp < 0;
68   }
69 };
70 
71 struct DescriptorComp {
operator ()protozero::__anonf52e20e30111::DescriptorComp72   bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
73     int comp = lhs->full_name().compare(rhs->full_name());
74     Assert(comp != 0 || lhs == rhs);
75     return comp < 0;
76   }
77 };
78 
79 struct EnumDescriptorComp {
operator ()protozero::__anonf52e20e30111::EnumDescriptorComp80   bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
81     int comp = lhs->full_name().compare(rhs->full_name());
82     Assert(comp != 0 || lhs == rhs);
83     return comp < 0;
84   }
85 };
86 
ProtoStubName(const FileDescriptor * proto)87 inline std::string ProtoStubName(const FileDescriptor* proto) {
88   return StripSuffix(proto->name(), ".proto") + ".pbzero";
89 }
90 
91 class GeneratorJob {
92  public:
  // Binds the job to |file|, the proto file descriptor the stubs are generated
  // from, and to |stub_h_printer|, the printer that receives the generated
  // header text. Both pointers are stored as-is; nothing is generated here.
  GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
      : source_(file), stub_h_(stub_h_printer) {}
95 
GenerateStubs()96   bool GenerateStubs() {
97     Preprocess();
98     GeneratePrologue();
99     for (const EnumDescriptor* enumeration : enums_)
100       GenerateEnumDescriptor(enumeration);
101     for (const Descriptor* message : messages_)
102       GenerateMessageDescriptor(message);
103     for (const auto& key_value : extensions_)
104       GenerateExtension(key_value.first, key_value.second);
105     GenerateEpilogue();
106     return error_.empty();
107   }
108 
SetOption(const std::string & name,const std::string & value)109   void SetOption(const std::string& name, const std::string& value) {
110     if (name == "wrapper_namespace") {
111       wrapper_namespace_ = value;
112     } else if (name == "sdk") {
113       sdk_mode_ = (value == "true" || value == "1");
114     } else {
115       Abort(std::string() + "Unknown plugin option '" + name + "'.");
116     }
117   }
118 
  // Returns the first error recorded by this job, or an empty string if no
  // error was recorded. If the generator fails to produce stubs for a proto
  // definition, the emitted output is undefined and only that first error is
  // kept.
  const std::string& GetFirstError() const { return error_; }
122 
123  private:
124   // Only the first error will be recorded.
Abort(const std::string & reason)125   void Abort(const std::string& reason) {
126     if (error_.empty())
127       error_ = reason;
128   }
129 
130   template <class T>
HasSamePackage(const T * descriptor) const131   bool HasSamePackage(const T* descriptor) const {
132     return descriptor->file()->package() == package_;
133   }
134 
135   // Get C++ class name corresponding to proto descriptor.
136   // Nested names are splitted by underscores. Underscores in type names aren't
137   // prohibited but not recommended in order to avoid name collisions.
138   template <class T>
GetCppClassName(const T * descriptor,bool full=false)139   inline std::string GetCppClassName(const T* descriptor, bool full = false) {
140     std::string package = descriptor->file()->package();
141     std::string name = StripPrefix(descriptor->full_name(), package + ".");
142     name = StripChars(name, ".", '_');
143 
144     if (full && !package.empty()) {
145       auto get_full_namespace = [&]() {
146         std::vector<std::string> namespaces = SplitString(package, ".");
147         if (!wrapper_namespace_.empty())
148           namespaces.push_back(wrapper_namespace_);
149 
150         std::string result = "";
151         for (const std::string& ns : namespaces) {
152           result += "::";
153           result += ns;
154         }
155         return result;
156       };
157 
158       std::string namespaces = ReplaceAll(package, ".", "::");
159       name = get_full_namespace() + "::" + name;
160     }
161 
162     return name;
163   }
164 
GetFieldNumberConstant(const FieldDescriptor * field)165   inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
166     std::string name = field->camelcase_name();
167     if (!name.empty()) {
168       name.at(0) = Uppercase(name.at(0));
169       name = "k" + name + "FieldNumber";
170     } else {
171       // Protoc allows fields like 'bool _ = 1'.
172       Abort("Empty field name in camel case notation.");
173     }
174     return name;
175   }
176 
177   // Note: intentionally avoiding depending on protozero sources, as well as
178   // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)179   const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
180     switch (proto_type) {
181       case FieldDescriptor::TYPE_INT64:
182       case FieldDescriptor::TYPE_UINT64:
183       case FieldDescriptor::TYPE_INT32:
184       case FieldDescriptor::TYPE_BOOL:
185       case FieldDescriptor::TYPE_UINT32:
186       case FieldDescriptor::TYPE_ENUM:
187       case FieldDescriptor::TYPE_SINT32:
188       case FieldDescriptor::TYPE_SINT64:
189         return "::protozero::proto_utils::ProtoWireType::kVarInt";
190 
191       case FieldDescriptor::TYPE_FIXED32:
192       case FieldDescriptor::TYPE_SFIXED32:
193       case FieldDescriptor::TYPE_FLOAT:
194         return "::protozero::proto_utils::ProtoWireType::kFixed32";
195 
196       case FieldDescriptor::TYPE_FIXED64:
197       case FieldDescriptor::TYPE_SFIXED64:
198       case FieldDescriptor::TYPE_DOUBLE:
199         return "::protozero::proto_utils::ProtoWireType::kFixed64";
200 
201       case FieldDescriptor::TYPE_STRING:
202       case FieldDescriptor::TYPE_MESSAGE:
203       case FieldDescriptor::TYPE_BYTES:
204         return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
205 
206       case FieldDescriptor::TYPE_GROUP:
207         Abort("Groups not supported.");
208     }
209     Abort("Unrecognized FieldDescriptor::Type.");
210     return "";
211   }
212 
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)213   const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
214     switch (proto_type) {
215       case FieldDescriptor::TYPE_INT64:
216       case FieldDescriptor::TYPE_UINT64:
217       case FieldDescriptor::TYPE_INT32:
218       case FieldDescriptor::TYPE_BOOL:
219       case FieldDescriptor::TYPE_UINT32:
220       case FieldDescriptor::TYPE_ENUM:
221       case FieldDescriptor::TYPE_SINT32:
222       case FieldDescriptor::TYPE_SINT64:
223         return "::protozero::PackedVarInt";
224 
225       case FieldDescriptor::TYPE_FIXED32:
226         return "::protozero::PackedFixedSizeInt<uint32_t>";
227       case FieldDescriptor::TYPE_SFIXED32:
228         return "::protozero::PackedFixedSizeInt<int32_t>";
229       case FieldDescriptor::TYPE_FLOAT:
230         return "::protozero::PackedFixedSizeInt<float>";
231 
232       case FieldDescriptor::TYPE_FIXED64:
233         return "::protozero::PackedFixedSizeInt<uint64_t>";
234       case FieldDescriptor::TYPE_SFIXED64:
235         return "::protozero::PackedFixedSizeInt<int64_t>";
236       case FieldDescriptor::TYPE_DOUBLE:
237         return "::protozero::PackedFixedSizeInt<double>";
238 
239       case FieldDescriptor::TYPE_STRING:
240       case FieldDescriptor::TYPE_MESSAGE:
241       case FieldDescriptor::TYPE_BYTES:
242       case FieldDescriptor::TYPE_GROUP:
243         Abort("Unexpected FieldDescritor::Type.");
244     }
245     Abort("Unrecognized FieldDescriptor::Type.");
246     return "";
247   }
248 
FieldToProtoSchemaType(const FieldDescriptor * field)249   const char* FieldToProtoSchemaType(const FieldDescriptor* field) {
250     switch (field->type()) {
251       case FieldDescriptor::TYPE_BOOL:
252         return "kBool";
253       case FieldDescriptor::TYPE_INT32:
254         return "kInt32";
255       case FieldDescriptor::TYPE_INT64:
256         return "kInt64";
257       case FieldDescriptor::TYPE_UINT32:
258         return "kUint32";
259       case FieldDescriptor::TYPE_UINT64:
260         return "kUint64";
261       case FieldDescriptor::TYPE_SINT32:
262         return "kSint32";
263       case FieldDescriptor::TYPE_SINT64:
264         return "kSint64";
265       case FieldDescriptor::TYPE_FIXED32:
266         return "kFixed32";
267       case FieldDescriptor::TYPE_FIXED64:
268         return "kFixed64";
269       case FieldDescriptor::TYPE_SFIXED32:
270         return "kSfixed32";
271       case FieldDescriptor::TYPE_SFIXED64:
272         return "kSfixed64";
273       case FieldDescriptor::TYPE_FLOAT:
274         return "kFloat";
275       case FieldDescriptor::TYPE_DOUBLE:
276         return "kDouble";
277       case FieldDescriptor::TYPE_ENUM:
278         return "kEnum";
279       case FieldDescriptor::TYPE_STRING:
280         return "kString";
281       case FieldDescriptor::TYPE_MESSAGE:
282         return "kMessage";
283       case FieldDescriptor::TYPE_BYTES:
284         return "kBytes";
285 
286       case FieldDescriptor::TYPE_GROUP:
287         Abort("Groups not supported.");
288         return "";
289     }
290     Abort("Unrecognized FieldDescriptor::Type.");
291     return "";
292   }
293 
FieldToCppTypeName(const FieldDescriptor * field)294   std::string FieldToCppTypeName(const FieldDescriptor* field) {
295     switch (field->type()) {
296       case FieldDescriptor::TYPE_BOOL:
297         return "bool";
298       case FieldDescriptor::TYPE_INT32:
299         return "int32_t";
300       case FieldDescriptor::TYPE_INT64:
301         return "int64_t";
302       case FieldDescriptor::TYPE_UINT32:
303         return "uint32_t";
304       case FieldDescriptor::TYPE_UINT64:
305         return "uint64_t";
306       case FieldDescriptor::TYPE_SINT32:
307         return "int32_t";
308       case FieldDescriptor::TYPE_SINT64:
309         return "int64_t";
310       case FieldDescriptor::TYPE_FIXED32:
311         return "uint32_t";
312       case FieldDescriptor::TYPE_FIXED64:
313         return "uint64_t";
314       case FieldDescriptor::TYPE_SFIXED32:
315         return "int32_t";
316       case FieldDescriptor::TYPE_SFIXED64:
317         return "int64_t";
318       case FieldDescriptor::TYPE_FLOAT:
319         return "float";
320       case FieldDescriptor::TYPE_DOUBLE:
321         return "double";
322       case FieldDescriptor::TYPE_ENUM:
323         return GetCppClassName(field->enum_type(),
324                                !HasSamePackage(field->enum_type()));
325       case FieldDescriptor::TYPE_STRING:
326       case FieldDescriptor::TYPE_BYTES:
327         return "std::string";
328       case FieldDescriptor::TYPE_MESSAGE:
329         return GetCppClassName(field->message_type(),
330                                !HasSamePackage(field->message_type()));
331       case FieldDescriptor::TYPE_GROUP:
332         Abort("Groups not supported.");
333         return "";
334     }
335     Abort("Unrecognized FieldDescriptor::Type.");
336     return "";
337   }
338 
FieldToRepetitionType(const FieldDescriptor * field)339   const char* FieldToRepetitionType(const FieldDescriptor* field) {
340     if (!field->is_repeated())
341       return "kNotRepeated";
342     if (field->is_packed())
343       return "kRepeatedPacked";
344     return "kRepeatedNotPacked";
345   }
346 
CollectDescriptors()347   void CollectDescriptors() {
348     // Collect message descriptors in DFS order.
349     std::vector<const Descriptor*> stack;
350     stack.reserve(static_cast<size_t>(source_->message_type_count()));
351     for (int i = 0; i < source_->message_type_count(); ++i)
352       stack.push_back(source_->message_type(i));
353 
354     while (!stack.empty()) {
355       const Descriptor* message = stack.back();
356       stack.pop_back();
357 
358       if (message->extension_count() > 0) {
359         if (message->field_count() > 0 || message->nested_type_count() > 0 ||
360             message->enum_type_count() > 0) {
361           Abort("message with extend blocks shouldn't contain anything else");
362         }
363 
364         // Iterate over all fields in "extend" blocks.
365         for (int i = 0; i < message->extension_count(); ++i) {
366           const FieldDescriptor* extension = message->extension(i);
367 
368           // Protoc plugin API does not group fields in "extend" blocks.
369           // As the support for extensions in protozero is limited, the code
370           // assumes that extend blocks are located inside a wrapper message and
371           // name of this message is used to group them.
372           std::string extension_name = extension->extension_scope()->name();
373           extensions_[extension_name].push_back(extension);
374 
375           if (extension->message_type()) {
376             // Emit a forward declaration of nested message types, as the outer
377             // class will refer to them when creating type aliases.
378             referenced_messages_.insert(extension->message_type());
379           }
380         }
381       } else {
382         messages_.push_back(message);
383         for (int i = 0; i < message->nested_type_count(); ++i) {
384           stack.push_back(message->nested_type(i));
385           // Emit a forward declaration of nested message types, as the outer
386           // class will refer to them when creating type aliases.
387           referenced_messages_.insert(message->nested_type(i));
388         }
389       }
390     }
391 
392     // Collect enums.
393     for (int i = 0; i < source_->enum_type_count(); ++i)
394       enums_.push_back(source_->enum_type(i));
395 
396     if (source_->extension_count() > 0) {
397       // TODO(b/336524288): emit field numbers
398     }
399 
400     for (const Descriptor* message : messages_) {
401       for (int i = 0; i < message->enum_type_count(); ++i) {
402         enums_.push_back(message->enum_type(i));
403       }
404     }
405   }
406 
CollectDependencies()407   void CollectDependencies() {
408     // Public import basically means that callers only need to import this
409     // proto in order to use the stuff publicly imported by this proto.
410     for (int i = 0; i < source_->public_dependency_count(); ++i)
411       public_imports_.insert(source_->public_dependency(i));
412 
413     if (source_->weak_dependency_count() > 0)
414       Abort("Weak imports are not supported.");
415 
416     // Validations. Collect public imports (of collected imports) in DFS order.
417     // Visibilty for current proto:
418     // - all imports listed in current proto,
419     // - public imports of everything imported (recursive).
420     std::vector<const FileDescriptor*> stack;
421     for (int i = 0; i < source_->dependency_count(); ++i) {
422       const FileDescriptor* import = source_->dependency(i);
423       stack.push_back(import);
424       if (public_imports_.count(import) == 0) {
425         private_imports_.insert(import);
426       }
427     }
428 
429     while (!stack.empty()) {
430       const FileDescriptor* import = stack.back();
431       stack.pop_back();
432       for (int i = 0; i < import->public_dependency_count(); ++i) {
433         stack.push_back(import->public_dependency(i));
434       }
435     }
436 
437     // Collect descriptors of messages and enums used in current proto.
438     // It will be used to generate necessary forward declarations and
439     // check that everything lays in the same namespace.
440     for (const Descriptor* message : messages_) {
441       for (int i = 0; i < message->field_count(); ++i) {
442         const FieldDescriptor* field = message->field(i);
443 
444         if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
445           if (public_imports_.count(field->message_type()->file()) == 0) {
446             // Avoid multiple forward declarations since
447             // public imports have been already included.
448             referenced_messages_.insert(field->message_type());
449           }
450         } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
451           if (public_imports_.count(field->enum_type()->file()) == 0) {
452             referenced_enums_.insert(field->enum_type());
453           }
454         }
455       }
456     }
457   }
458 
Preprocess()459   void Preprocess() {
460     // Package name maps to a series of namespaces.
461     package_ = source_->package();
462     namespaces_ = SplitString(package_, ".");
463     if (!wrapper_namespace_.empty())
464       namespaces_.push_back(wrapper_namespace_);
465 
466     full_namespace_prefix_ = "::";
467     for (const std::string& ns : namespaces_)
468       full_namespace_prefix_ += ns + "::";
469 
470     CollectDescriptors();
471     CollectDependencies();
472   }
473 
GetNamespaceNameForInnerEnum(const EnumDescriptor * enumeration)474   std::string GetNamespaceNameForInnerEnum(const EnumDescriptor* enumeration) {
475     return "perfetto_pbzero_enum_" +
476            GetCppClassName(enumeration->containing_type());
477   }
478 
479   // Print top header, namespaces and forward declarations.
GeneratePrologue()480   void GeneratePrologue() {
481     std::string greeting =
482         "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
483     std::string guard = package_ + "_" + source_->name() + "_H_";
484     guard = ToUpper(guard);
485     guard = StripChars(guard, ".-/\\", '_');
486 
487     stub_h_->Print(
488         "$greeting$\n"
489         "#ifndef $guard$\n"
490         "#define $guard$\n\n"
491         "#include <stddef.h>\n"
492         "#include <stdint.h>\n\n",
493         "greeting", greeting, "guard", guard);
494 
495     if (sdk_mode_) {
496       stub_h_->Print("#include \"perfetto.h\"\n");
497     } else {
498       stub_h_->Print(
499           "#include \"perfetto/protozero/field_writer.h\"\n"
500           "#include \"perfetto/protozero/message.h\"\n"
501           "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
502           "#include \"perfetto/protozero/proto_decoder.h\"\n"
503           "#include \"perfetto/protozero/proto_utils.h\"\n");
504     }
505 
506     // Print includes for public imports. In sdk mode, all imports are assumed
507     // to be part of the sdk.
508     if (!sdk_mode_) {
509       for (const FileDescriptor* dependency : public_imports_) {
510         // Dependency name could contain slashes but importing from upper-level
511         // directories is not possible anyway since build system processes each
512         // proto file individually. Hence proto lookup path is always equal to
513         // the directory where particular proto file is located and protoc does
514         // not allow reference to upper directory (aka ..) in import path.
515         //
516         // Laconically said:
517         // - source_->name() may never have slashes,
518         // - dependency->name() may have slashes but always refers to inner
519         // path.
520         stub_h_->Print("#include \"$name$.h\"\n", "name",
521                        ProtoStubName(dependency));
522       }
523     }
524     stub_h_->Print("\n");
525 
526     PrintForwardDeclarations();
527 
528     // Print namespaces.
529     for (const std::string& ns : namespaces_) {
530       stub_h_->Print("namespace $ns$ {\n", "ns", ns);
531     }
532     stub_h_->Print("\n");
533   }
534 
PrintForwardDeclarations()535   void PrintForwardDeclarations() {
536     struct Descriptors {
537       std::vector<const Descriptor*> messages_;
538       std::vector<const EnumDescriptor*> enums_;
539     };
540     std::map<std::string, Descriptors> package_to_descriptors;
541 
542     for (const Descriptor* message : referenced_messages_) {
543       package_to_descriptors[message->file()->package()].messages_.push_back(
544           message);
545     }
546 
547     for (const EnumDescriptor* enumeration : referenced_enums_) {
548       package_to_descriptors[enumeration->file()->package()].enums_.push_back(
549           enumeration);
550     }
551 
552     for (const auto& [package, descriptors] : package_to_descriptors) {
553       std::vector<std::string> namespaces = SplitString(package, ".");
554       namespaces.push_back(wrapper_namespace_);
555 
556       // open namespaces
557       for (const auto& ns : namespaces) {
558         stub_h_->Print("namespace $ns$ {\n", "ns", ns);
559       }
560 
561       for (const Descriptor* message : descriptors.messages_) {
562         stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
563       }
564 
565       for (const EnumDescriptor* enumeration : descriptors.enums_) {
566         if (enumeration->containing_type()) {
567           stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
568                          GetNamespaceNameForInnerEnum(enumeration));
569         }
570         stub_h_->Print("enum $class$ : int32_t;\n", "class",
571                        enumeration->name());
572 
573         if (enumeration->containing_type()) {
574           stub_h_->Print("}  // namespace $namespace_name$\n", "namespace_name",
575                          GetNamespaceNameForInnerEnum(enumeration));
576           stub_h_->Print("using $alias$ = $namespace_name$::$short_name$;\n",
577                          "alias", GetCppClassName(enumeration),
578                          "namespace_name",
579                          GetNamespaceNameForInnerEnum(enumeration),
580                          "short_name", enumeration->name());
581         }
582       }
583 
584       // close namespaces
585       for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
586         stub_h_->Print("} // Namespace $ns$.\n", "ns", *it);
587       }
588     }
589 
590     stub_h_->Print("\n");
591   }
592 
GenerateEnumDescriptor(const EnumDescriptor * enumeration)593   void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
594     bool is_inner_enum = !!enumeration->containing_type();
595     if (is_inner_enum) {
596       stub_h_->Print("namespace $namespace_name$ {\n", "namespace_name",
597                      GetNamespaceNameForInnerEnum(enumeration));
598     }
599 
600     stub_h_->Print("enum $class$ : int32_t {\n", "class", enumeration->name());
601     stub_h_->Indent();
602 
603     std::string min_name, max_name;
604     int min_val = std::numeric_limits<int>::max();
605     int max_val = -1;
606     for (int i = 0; i < enumeration->value_count(); ++i) {
607       const EnumValueDescriptor* value = enumeration->value(i);
608       const std::string value_name = value->name();
609       stub_h_->Print("$name$ = $number$,\n", "name", value_name, "number",
610                      std::to_string(value->number()));
611       if (value->number() < min_val) {
612         min_val = value->number();
613         min_name = value_name;
614       }
615       if (value->number() > max_val) {
616         max_val = value->number();
617         max_name = value_name;
618       }
619     }
620     stub_h_->Outdent();
621     stub_h_->Print("};\n");
622     if (is_inner_enum) {
623       const std::string namespace_name =
624           GetNamespaceNameForInnerEnum(enumeration);
625       stub_h_->Print("} // namespace $namespace_name$\n", "namespace_name",
626                      namespace_name);
627       stub_h_->Print(
628           "using $full_enum_name$ = $namespace_name$::$enum_name$;\n\n",
629           "full_enum_name", GetCppClassName(enumeration), "enum_name",
630           enumeration->name(), "namespace_name", namespace_name);
631     }
632     stub_h_->Print("\n");
633     stub_h_->Print("constexpr $class$ $class$_MIN = $class$::$min$;\n", "class",
634                    GetCppClassName(enumeration), "min", min_name);
635     stub_h_->Print("constexpr $class$ $class$_MAX = $class$::$max$;\n", "class",
636                    GetCppClassName(enumeration), "max", max_name);
637     stub_h_->Print("\n");
638 
639     GenerateEnumToStringConversion(enumeration);
640   }
641 
GenerateEnumToStringConversion(const EnumDescriptor * enumeration)642   void GenerateEnumToStringConversion(const EnumDescriptor* enumeration) {
643     std::string fullClassName =
644         full_namespace_prefix_ + GetCppClassName(enumeration);
645     const char* function_header_stub = R"(
646 PERFETTO_PROTOZERO_CONSTEXPR14_OR_INLINE
647 const char* $class_name$_Name($full_class$ value) {
648 )";
649     stub_h_->Print(function_header_stub, "full_class", fullClassName,
650                    "class_name", GetCppClassName(enumeration));
651     stub_h_->Indent();
652     stub_h_->Print("switch (value) {");
653     for (int index = 0; index < enumeration->value_count(); ++index) {
654       const EnumValueDescriptor* value = enumeration->value(index);
655       const char* switch_stub = R"(
656 case $full_class$::$value_name$:
657   return "$value_name$";
658 )";
659       stub_h_->Print(switch_stub, "full_class", fullClassName, "value_name",
660                      value->name());
661     }
662     stub_h_->Print("}\n");
663     stub_h_->Print(R"(return "PBZERO_UNKNOWN_ENUM_VALUE";)");
664     stub_h_->Print("\n");
665     stub_h_->Outdent();
666     stub_h_->Print("}\n\n");
667   }
668 
669   // Packed repeated fields are encoded as a length-delimited field on the wire,
670   // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)671   void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
672     std::map<std::string, std::string> setter;
673     setter["name"] = field->lowercase_name();
674     setter["field_metadata"] = GetFieldMetadataTypeName(field);
675     setter["action"] = "set";
676     setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
677     stub_h_->Print(
678         setter,
679         "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
680         "  AppendBytes($field_metadata$::kFieldId, packed_buffer.data(),\n"
681         "              packed_buffer.size());\n"
682         "}\n");
683   }
684 
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)685   void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
686     std::map<std::string, std::string> setter;
687     setter["id"] = std::to_string(field->number());
688     setter["name"] = field->lowercase_name();
689     setter["field_metadata"] = GetFieldMetadataTypeName(field);
690     setter["action"] = field->is_repeated() ? "add" : "set";
691     setter["cpp_type"] = FieldToCppTypeName(field);
692     setter["proto_field_type"] = FieldToProtoSchemaType(field);
693 
694     const char* code_stub =
695         "void $action$_$name$($cpp_type$ value) {\n"
696         "  static constexpr uint32_t field_id = $field_metadata$::kFieldId;\n"
697         "  // Call the appropriate protozero::Message::Append(field_id, ...)\n"
698         "  // method based on the type of the field.\n"
699         "  ::protozero::internal::FieldWriter<\n"
700         "    ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$>\n"
701         "      ::Append(*this, field_id, value);\n"
702         "}\n";
703 
704     if (field->type() == FieldDescriptor::TYPE_STRING) {
705       // Strings and bytes should have an additional accessor which specifies
706       // the length explicitly.
707       const char* additional_method =
708           "void $action$_$name$(const char* data, size_t size) {\n"
709           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
710           "}\n"
711           "void $action$_$name$(::protozero::ConstChars chars) {\n"
712           "  AppendBytes($field_metadata$::kFieldId, chars.data, chars.size);\n"
713           "}\n";
714       stub_h_->Print(setter, additional_method);
715     } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
716       const char* additional_method =
717           "void $action$_$name$(const uint8_t* data, size_t size) {\n"
718           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
719           "}\n"
720           "void $action$_$name$(::protozero::ConstBytes bytes) {\n"
721           "  AppendBytes($field_metadata$::kFieldId, bytes.data, bytes.size);\n"
722           "}\n";
723       stub_h_->Print(setter, additional_method);
724     } else if (field->type() == FieldDescriptor::TYPE_GROUP ||
725                field->type() == FieldDescriptor::TYPE_MESSAGE) {
726       Abort("Unsupported field type.");
727       return;
728     }
729 
730     stub_h_->Print(setter, code_stub);
731   }
732 
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)733   void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
734     std::string action = field->is_repeated() ? "add" : "set";
735     std::string inner_class = GetCppClassName(
736         field->message_type(), !HasSamePackage(field->message_type()));
737     stub_h_->Print(
738         "template <typename T = $inner_class$> T* $action$_$name$() {\n"
739         "  return BeginNestedMessage<T>($id$);\n"
740         "}\n\n",
741         "id", std::to_string(field->number()), "name", field->lowercase_name(),
742         "action", action, "inner_class", inner_class);
743     if (field->options().lazy()) {
744       stub_h_->Print(
745           "void $action$_$name$_raw(const std::string& raw) {\n"
746           "  return AppendBytes($id$, raw.data(), raw.size());\n"
747           "}\n\n",
748           "id", std::to_string(field->number()), "name",
749           field->lowercase_name(), "action", action, "inner_class",
750           inner_class);
751     }
752   }
753 
GenerateDecoder(const Descriptor * message)754   void GenerateDecoder(const Descriptor* message) {
755     int max_field_id = 0;
756     bool has_nonpacked_repeated_fields = false;
757     for (int i = 0; i < message->field_count(); ++i) {
758       const FieldDescriptor* field = message->field(i);
759       if (field->number() > kMaxDecoderFieldId)
760         continue;
761       max_field_id = std::max(max_field_id, field->number());
762       if (field->is_repeated() && !field->is_packed())
763         has_nonpacked_repeated_fields = true;
764     }
765     // Iterate over all fields in "extend" blocks.
766     for (int i = 0; i < message->extension_range_count(); ++i) {
767       Descriptor::ExtensionRange::Proto range;
768       message->extension_range(i)->CopyTo(&range);
769       int candidate = range.end() - 1;
770       if (candidate > kMaxDecoderFieldId)
771         continue;
772       max_field_id = std::max(max_field_id, candidate);
773     }
774 
775     std::string class_name = GetCppClassName(message) + "_Decoder";
776     stub_h_->Print(
777         "class $name$ : public "
778         "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
779         "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
780         "name", class_name, "max", std::to_string(max_field_id), "rep",
781         has_nonpacked_repeated_fields ? "true" : "false");
782     stub_h_->Print(" public:\n");
783     stub_h_->Indent();
784     stub_h_->Print(
785         "$name$(const uint8_t* data, size_t len) "
786         ": TypedProtoDecoder(data, len) {}\n",
787         "name", class_name);
788     stub_h_->Print(
789         "explicit $name$(const std::string& raw) : "
790         "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
791         "raw.size()) {}\n",
792         "name", class_name);
793     stub_h_->Print(
794         "explicit $name$(const ::protozero::ConstBytes& raw) : "
795         "TypedProtoDecoder(raw.data, raw.size) {}\n",
796         "name", class_name);
797 
798     for (int i = 0; i < message->field_count(); ++i) {
799       const FieldDescriptor* field = message->field(i);
800       if (field->number() > max_field_id) {
801         stub_h_->Print("// field $name$ omitted because its id is too high\n",
802                        "name", field->name());
803         continue;
804       }
805       std::string getter;
806       std::string cpp_type;
807       switch (field->type()) {
808         case FieldDescriptor::TYPE_BOOL:
809           getter = "as_bool";
810           cpp_type = "bool";
811           break;
812         case FieldDescriptor::TYPE_SFIXED32:
813         case FieldDescriptor::TYPE_INT32:
814           getter = "as_int32";
815           cpp_type = "int32_t";
816           break;
817         case FieldDescriptor::TYPE_SINT32:
818           getter = "as_sint32";
819           cpp_type = "int32_t";
820           break;
821         case FieldDescriptor::TYPE_SFIXED64:
822         case FieldDescriptor::TYPE_INT64:
823           getter = "as_int64";
824           cpp_type = "int64_t";
825           break;
826         case FieldDescriptor::TYPE_SINT64:
827           getter = "as_sint64";
828           cpp_type = "int64_t";
829           break;
830         case FieldDescriptor::TYPE_FIXED32:
831         case FieldDescriptor::TYPE_UINT32:
832           getter = "as_uint32";
833           cpp_type = "uint32_t";
834           break;
835         case FieldDescriptor::TYPE_FIXED64:
836         case FieldDescriptor::TYPE_UINT64:
837           getter = "as_uint64";
838           cpp_type = "uint64_t";
839           break;
840         case FieldDescriptor::TYPE_FLOAT:
841           getter = "as_float";
842           cpp_type = "float";
843           break;
844         case FieldDescriptor::TYPE_DOUBLE:
845           getter = "as_double";
846           cpp_type = "double";
847           break;
848         case FieldDescriptor::TYPE_ENUM:
849           getter = "as_int32";
850           cpp_type = "int32_t";
851           break;
852         case FieldDescriptor::TYPE_STRING:
853           getter = "as_string";
854           cpp_type = "::protozero::ConstChars";
855           break;
856         case FieldDescriptor::TYPE_MESSAGE:
857         case FieldDescriptor::TYPE_BYTES:
858           getter = "as_bytes";
859           cpp_type = "::protozero::ConstBytes";
860           break;
861         case FieldDescriptor::TYPE_GROUP:
862           continue;
863       }
864 
865       stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
866                      "name", field->lowercase_name(), "id",
867                      std::to_string(field->number()));
868 
869       if (field->is_packed()) {
870         const char* protozero_wire_type =
871             FieldTypeToProtozeroWireType(field->type());
872         stub_h_->Print(
873             "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
874             "$name$(bool* parse_error_ptr) const { return "
875             "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
876             "parse_error_ptr); }\n",
877             "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
878             field->lowercase_name(), "id", std::to_string(field->number()));
879       } else if (field->is_repeated()) {
880         stub_h_->Print(
881             "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
882             "return "
883             "GetRepeated<$cpp_type$>($id$); }\n",
884             "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
885             std::to_string(field->number()));
886       } else {
887         stub_h_->Print(
888             "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
889             "name", field->lowercase_name(), "id",
890             std::to_string(field->number()), "cpp_type", cpp_type, "getter",
891             getter);
892       }
893     }
894     stub_h_->Outdent();
895     stub_h_->Print("};\n\n");
896   }
897 
GenerateConstantsForMessageFields(const Descriptor * message)898   void GenerateConstantsForMessageFields(const Descriptor* message) {
899     const bool has_fields =
900         message->field_count() > 0 || message->extension_count() > 0;
901 
902     // Field number constants.
903     if (has_fields) {
904       stub_h_->Print("enum : int32_t {\n");
905       stub_h_->Indent();
906 
907       for (int i = 0; i < message->field_count(); ++i) {
908         const FieldDescriptor* field = message->field(i);
909         stub_h_->Print("$name$ = $id$,\n", "name",
910                        GetFieldNumberConstant(field), "id",
911                        std::to_string(field->number()));
912       }
913 
914       for (int i = 0; i < message->extension_count(); ++i) {
915         const FieldDescriptor* field = message->extension(i);
916 
917         stub_h_->Print("$name$ = $id$,\n", "name",
918                        GetFieldNumberConstant(field), "id",
919                        std::to_string(field->number()));
920       }
921 
922       stub_h_->Outdent();
923       stub_h_->Print("};\n");
924     }
925   }
926 
GenerateMessageDescriptor(const Descriptor * message)927   void GenerateMessageDescriptor(const Descriptor* message) {
928     GenerateDecoder(message);
929 
930     stub_h_->Print(
931         "class $name$ : public ::protozero::Message {\n"
932         " public:\n",
933         "name", GetCppClassName(message));
934     stub_h_->Indent();
935 
936     stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
937                    GetCppClassName(message));
938 
939     GenerateConstantsForMessageFields(message);
940 
941     stub_h_->Print(
942         "static constexpr const char* GetName() { return \".$name$\"; }\n\n",
943         "name", message->full_name());
944 
945     // Using statements for nested messages.
946     for (int i = 0; i < message->nested_type_count(); ++i) {
947       const Descriptor* nested_message = message->nested_type(i);
948       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
949                      nested_message->name(), "global_name",
950                      GetCppClassName(nested_message, true));
951     }
952 
953     // Using statements for nested enums.
954     for (int i = 0; i < message->enum_type_count(); ++i) {
955       const EnumDescriptor* nested_enum = message->enum_type(i);
956       const char* stub = R"(
957 using $local_name$ = $global_name$;
958 static inline const char* $local_name$_Name($local_name$ value) {
959   return $global_name$_Name(value);
960 }
961 )";
962       stub_h_->Print(stub, "local_name", nested_enum->name(), "global_name",
963                      GetCppClassName(nested_enum, true));
964     }
965 
966     // Values of nested enums.
967     for (int i = 0; i < message->enum_type_count(); ++i) {
968       const EnumDescriptor* nested_enum = message->enum_type(i);
969 
970       for (int j = 0; j < nested_enum->value_count(); ++j) {
971         const EnumValueDescriptor* value = nested_enum->value(j);
972         stub_h_->Print(
973             "static inline const $class$ $name$ = $class$::$name$;\n", "class",
974             nested_enum->name(), "name", value->name());
975       }
976     }
977 
978     // Field descriptors.
979     for (int i = 0; i < message->field_count(); ++i) {
980       GenerateFieldDescriptor(GetCppClassName(message), message->field(i));
981     }
982 
983     stub_h_->Outdent();
984     stub_h_->Print("};\n\n");
985   }
986 
GetFieldMetadataTypeName(const FieldDescriptor * field)987   std::string GetFieldMetadataTypeName(const FieldDescriptor* field) {
988     std::string name = field->camelcase_name();
989     if (isalpha(name[0]))
990       name[0] = static_cast<char>(toupper(name[0]));
991     return "FieldMetadata_" + name;
992   }
993 
GetFieldMetadataVariableName(const FieldDescriptor * field)994   std::string GetFieldMetadataVariableName(const FieldDescriptor* field) {
995     std::string name = field->camelcase_name();
996     if (isalpha(name[0]))
997       name[0] = static_cast<char>(toupper(name[0]));
998     return "k" + name;
999   }
1000 
GenerateFieldMetadata(const std::string & message_cpp_type,const FieldDescriptor * field)1001   void GenerateFieldMetadata(const std::string& message_cpp_type,
1002                              const FieldDescriptor* field) {
1003     const char* code_stub = R"(
1004 using $field_metadata_type$ =
1005   ::protozero::proto_utils::FieldMetadata<
1006     $field_id$,
1007     ::protozero::proto_utils::RepetitionType::$repetition_type$,
1008     ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$,
1009     $cpp_type$,
1010     $message_cpp_type$>;
1011 
1012 static constexpr $field_metadata_type$ $field_metadata_var${};
1013 )";
1014 
1015     stub_h_->Print(code_stub, "field_id", std::to_string(field->number()),
1016                    "repetition_type", FieldToRepetitionType(field),
1017                    "proto_field_type", FieldToProtoSchemaType(field),
1018                    "cpp_type", FieldToCppTypeName(field), "message_cpp_type",
1019                    message_cpp_type, "field_metadata_type",
1020                    GetFieldMetadataTypeName(field), "field_metadata_var",
1021                    GetFieldMetadataVariableName(field));
1022   }
1023 
GenerateFieldDescriptor(const std::string & message_cpp_type,const FieldDescriptor * field)1024   void GenerateFieldDescriptor(const std::string& message_cpp_type,
1025                                const FieldDescriptor* field) {
1026     GenerateFieldMetadata(message_cpp_type, field);
1027     if (field->is_packed()) {
1028       GeneratePackedRepeatedFieldDescriptor(field);
1029     } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
1030       GenerateSimpleFieldDescriptor(field);
1031     } else {
1032       GenerateNestedMessageFieldDescriptor(field);
1033     }
1034   }
1035 
1036   // Generate extension class for a group of FieldDescriptor instances
1037   // representing one "extend" block in proto definition. For example:
1038   //
1039   //   message SpecificExtension {
1040   //     extend GeneralThing {
1041   //       optional Fizz fizz = 101;
1042   //       optional Buzz buzz = 102;
1043   //     }
1044   //   }
1045   //
1046   // This is going to be passed as a vector of two elements, "fizz" and
1047   // "buzz". Wrapping message is used to provide a name for generated
1048   // extension class.
1049   //
1050   // In the example above, generated code is going to look like:
1051   //
1052   //   class SpecificExtension : public GeneralThing {
1053   //     Fizz* set_fizz();
1054   //     Buzz* set_buzz();
1055   //   }
GenerateExtension(const std::string & extension_name,const std::vector<const FieldDescriptor * > & descriptors)1056   void GenerateExtension(
1057       const std::string& extension_name,
1058       const std::vector<const FieldDescriptor*>& descriptors) {
1059     // Use an arbitrary descriptor in order to get generic information not
1060     // specific to any of them.
1061     const FieldDescriptor* descriptor = descriptors[0];
1062     const Descriptor* base_message = descriptor->containing_type();
1063 
1064     // TODO(ddrone): ensure that this code works when containing_type located in
1065     // other file or namespace.
1066     stub_h_->Print("class $name$ : public $extendee$ {\n", "name",
1067                    extension_name, "extendee",
1068                    GetCppClassName(base_message, /*full=*/true));
1069     stub_h_->Print(" public:\n");
1070     stub_h_->Indent();
1071     for (const FieldDescriptor* field : descriptors) {
1072       if (field->containing_type() != base_message) {
1073         Abort("one wrapper should extend only one message");
1074         return;
1075       }
1076       GenerateFieldDescriptor(extension_name, field);
1077     }
1078 
1079     if (!descriptors.empty()) {
1080       stub_h_->Print("enum : int32_t {\n");
1081       stub_h_->Indent();
1082 
1083       for (const FieldDescriptor* field : descriptors) {
1084         stub_h_->Print("$name$ = $id$,\n", "name",
1085                        GetFieldNumberConstant(field), "id",
1086                        std::to_string(field->number()));
1087       }
1088       stub_h_->Outdent();
1089       stub_h_->Print("};\n");
1090     }
1091 
1092     stub_h_->Outdent();
1093     stub_h_->Print("};\n");
1094   }
1095 
GenerateEpilogue()1096   void GenerateEpilogue() {
1097     for (unsigned i = 0; i < namespaces_.size(); ++i) {
1098       stub_h_->Print("} // Namespace.\n");
1099     }
1100     stub_h_->Print("#endif  // Include guard.\n");
1101   }
1102 
1103   const FileDescriptor* const source_;
1104   Printer* const stub_h_;
1105   std::string error_;
1106 
1107   std::string package_;
1108   std::string wrapper_namespace_;
1109   std::vector<std::string> namespaces_;
1110   std::string full_namespace_prefix_;
1111   std::vector<const Descriptor*> messages_;
1112   std::vector<const EnumDescriptor*> enums_;
1113   std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;
1114 
1115   // Generate headers that can be used with the Perfetto SDK.
1116   bool sdk_mode_ = false;
1117 
1118   // The custom *Comp comparators are to ensure determinism of the generator.
1119   std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
1120   std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
1121   std::set<const Descriptor*, DescriptorComp> referenced_messages_;
1122   std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
1123 };
1124 
1125 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
1126  public:
1127   explicit ProtoZeroGenerator();
1128   ~ProtoZeroGenerator() override;
1129 
1130   // CodeGenerator implementation
1131   bool Generate(const google::protobuf::FileDescriptor* file,
1132                 const std::string& options,
1133                 GeneratorContext* context,
1134                 std::string* error) const override;
1135 };
1136 
ProtoZeroGenerator()1137 ProtoZeroGenerator::ProtoZeroGenerator() {}
1138 
~ProtoZeroGenerator()1139 ProtoZeroGenerator::~ProtoZeroGenerator() {}
1140 
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const1141 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
1142                                   const std::string& options,
1143                                   GeneratorContext* context,
1144                                   std::string* error) const {
1145   const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
1146       context->Open(ProtoStubName(file) + ".h"));
1147   const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
1148       context->Open(ProtoStubName(file) + ".cc"));
1149 
1150   // Variables are delimited by $.
1151   Printer stub_h_printer(stub_h_file_stream.get(), '$');
1152   GeneratorJob job(file, &stub_h_printer);
1153 
1154   Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
1155   stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
1156 
1157   // Parse additional options.
1158   for (const std::string& option : SplitString(options, ",")) {
1159     std::vector<std::string> option_pair = SplitString(option, "=");
1160     job.SetOption(option_pair[0], option_pair[1]);
1161   }
1162 
1163   if (!job.GenerateStubs()) {
1164     *error = job.GetFirstError();
1165     return false;
1166   }
1167   return true;
1168 }
1169 
1170 }  // namespace
1171 }  // namespace protozero
1172 
main(int argc,char * argv[])1173 int main(int argc, char* argv[]) {
1174   ::protozero::ProtoZeroGenerator generator;
1175   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
1176 }
1177