xref: /aosp_15_r20/external/bazelbuild-rules_rust/proto/prost/private/protoc_wrapper.rs (revision d4726bddaa87cc4778e7472feed243fa4b6c267f)
1 //! A process wrapper for running a Protobuf compiler configured for Prost or Tonic output in a Bazel rule.
2 
3 use std::collections::BTreeMap;
4 use std::collections::BTreeSet;
5 use std::fmt::{Display, Formatter, Write};
6 use std::fs;
7 use std::io::BufRead;
8 use std::path::Path;
9 use std::path::PathBuf;
10 use std::process;
11 use std::{env, fmt};
12 
13 use heck::{ToSnakeCase, ToUpperCamelCase};
14 use prost::Message;
15 use prost_types::{
16     DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
17     OneofDescriptorProto,
18 };
19 
/// Recursively collect the prost outputs (`*.rs` files) found below `out_dir`.
///
/// An extension-less file named `_` (written when the package name is empty)
/// is renamed in place to `_.rs` and reported under its new name. Files with
/// any other extension, or other extension-less names, are ignored.
///
/// # Panics
///
/// Panics if a directory or entry cannot be read, or if the rename fails.
fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
    let mut found = BTreeSet::new();

    let entries = fs::read_dir(out_dir).expect("Failed to read directory");
    for entry in entries {
        let path = entry.expect("Failed to read entry").path();

        if path.is_dir() {
            found.extend(find_generated_rust_files(&path));
            continue;
        }

        match path.extension() {
            Some(ext) => {
                if ext == "rs" {
                    found.insert(path);
                }
            }
            None => {
                let is_empty_package_output = match path.file_name() {
                    Some(name) => name == "_",
                    None => false,
                };
                if is_empty_package_output {
                    let rs_name = path.parent().expect("Failed to get parent").join("_.rs");
                    fs::rename(&path, &rs_name).unwrap_or_else(|err| {
                        panic!("Failed to rename file: {err:?}: {path:?} -> {rs_name:?}")
                    });
                    found.insert(rs_name);
                }
            }
        }
    }

    found
}
48 
snake_cased_package_name(package: &str) -> String49 fn snake_cased_package_name(package: &str) -> String {
50     if package == "_" {
51         return package.to_owned();
52     }
53 
54     package
55         .split('.')
56         .map(|s| s.to_snake_case())
57         .collect::<Vec<_>>()
58         .join(".")
59 }
60 
/// Rust module definition.
///
/// One value exists per dotted module path collected by `generate_lib_rs`.
/// It is either backed by a generated file, or it is an empty placeholder
/// created for a parent package that has no file of its own.
#[derive(Debug, Default)]
struct Module {
    /// The name of the module: the last segment of its dotted path, or `_`
    /// for output generated for an empty package name.
    name: String,

    /// The contents of the module (generated Rust source; empty for placeholders).
    contents: String,

    /// The last-segment names of this module's direct submodules. A child's
    /// full key is `<this module's key>.<name>`.
    submodules: BTreeSet<String>,
}
73 
74 /// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
75 /// mirror the proto packages. For the example proto file we would expect to see
76 /// the Rust output that follows it.
77 ///
78 /// ```proto
79 /// syntax = "proto3";
80 /// package examples.prost.helloworld;
81 ///
82 /// message HelloRequest {
83 ///     // Request message contains the name to be greeted
84 ///     string name = 1;
85 /// }
86 //
87 /// message HelloReply {
88 ///     // Reply contains the greeting message
89 ///     string message = 1;
90 /// }
91 /// ```
92 ///
93 /// This is expected to render out to something like the following. Note that
94 /// formatting is not applied so indentation may be missing in the actual output.
95 ///
96 /// ```ignore
97 /// pub mod examples {
98 ///     pub mod prost {
99 ///         pub mod helloworld {
100 ///             // @generated
101 ///             #[allow(clippy::derive_partial_eq_without_eq)]
102 ///             #[derive(Clone, PartialEq, ::prost::Message)]
103 ///             pub struct HelloRequest {
104 ///                 /// Request message contains the name to be greeted
105 ///                 #[prost(string, tag = "1")]
106 ///                 pub name: ::prost::alloc::string::String,
107 ///             }
108 ///             #[allow(clippy::derive_partial_eq_without_eq)]
109 ///             #[derive(Clone, PartialEq, ::prost::Message)]
110 ///             pub struct HelloReply {
111 ///                 /// Reply contains the greeting message
112 ///                 #[prost(string, tag = "1")]
113 ///                 pub message: ::prost::alloc::string::String,
114 ///             }
115 ///             // @protoc_insertion_point(module)
116 ///         }
117 ///     }
118 /// }
119 /// ```
generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String120 fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String {
121     let mut module_info = BTreeMap::new();
122 
123     for path in prost_outputs.iter() {
124         let mut package = path
125             .file_stem()
126             .expect("Failed to get file stem")
127             .to_str()
128             .expect("Failed to convert to str")
129             .to_string();
130 
131         if is_tonic {
132             package = package
133                 .strip_suffix(".tonic")
134                 .expect("Failed to strip suffix")
135                 .to_string()
136         };
137 
138         if package.is_empty() {
139             continue;
140         }
141 
142         let name = if package == "_" {
143             package.clone()
144         } else if package.contains('.') {
145             package
146                 .rsplit_once('.')
147                 .expect("Failed to split on '.'")
148                 .1
149                 .to_snake_case()
150                 .to_string()
151         } else {
152             package.to_snake_case()
153         };
154 
155         // Avoid a stack overflow by skipping a known bad package name
156         let module_name = snake_cased_package_name(&package);
157 
158         module_info.insert(
159             module_name.clone(),
160             Module {
161                 name,
162                 contents: fs::read_to_string(path).expect("Failed to read file"),
163                 submodules: BTreeSet::new(),
164             },
165         );
166 
167         let module_parts = module_name.split('.').collect::<Vec<_>>();
168         for parent_module_index in 0..module_parts.len() {
169             let child_module_index = parent_module_index + 1;
170             if child_module_index >= module_parts.len() {
171                 break;
172             }
173             let full_parent_module_name = module_parts[0..parent_module_index + 1].join(".");
174             let parent_module_name = module_parts[parent_module_index];
175             let child_module_name = module_parts[child_module_index];
176 
177             module_info
178                 .entry(full_parent_module_name.clone())
179                 .and_modify(|parent_module| {
180                     parent_module
181                         .submodules
182                         .insert(child_module_name.to_string());
183                 })
184                 .or_insert(Module {
185                     name: parent_module_name.to_string(),
186                     contents: "".to_string(),
187                     submodules: [child_module_name.to_string()].iter().cloned().collect(),
188                 });
189         }
190     }
191 
192     let mut content = "// @generated\n\n".to_string();
193     write_module(&mut content, &module_info, "", 0);
194     content
195 }
196 
197 /// Write out a rust module and all of its submodules.
write_module( content: &mut String, module_info: &BTreeMap<String, Module>, module_name: &str, depth: usize, )198 fn write_module(
199     content: &mut String,
200     module_info: &BTreeMap<String, Module>,
201     module_name: &str,
202     depth: usize,
203 ) {
204     if module_name.is_empty() {
205         for submodule_name in module_info.keys() {
206             write_module(content, module_info, submodule_name, depth + 1);
207         }
208         return;
209     }
210     let module = module_info.get(module_name).expect("Failed to get module");
211     let indent = "  ".repeat(depth);
212     let is_rust_module = module.name != "_";
213 
214     if is_rust_module {
215         let rust_module_name = escape_keyword(module.name.clone());
216         content
217             .write_str(&format!("{}pub mod {} {{\n", indent, rust_module_name))
218             .expect("Failed to write string");
219     }
220 
221     content
222         .write_str(&module.contents)
223         .expect("Failed to write string");
224 
225     for submodule_name in module.submodules.iter() {
226         write_module(
227             content,
228             module_info,
229             [module_name, submodule_name].join(".").as_str(),
230             depth + 1,
231         );
232     }
233 
234     if is_rust_module {
235         content
236             .write_str(&format!("{}}}\n", indent))
237             .expect("Failed to write string");
238     }
239 }
240 
/// ProtoPath is a dotted path to a proto message, enum, or oneof.
///
/// Example: `helloworld.Greeter.HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct ProtoPath(String);

impl ProtoPath {
    /// Append `component` using a `.` separator.
    ///
    /// If either side is empty, the other side is returned unchanged (no separator).
    fn join(&self, component: &str) -> ProtoPath {
        let joined = match (self.0.as_str(), component) {
            ("", tail) => tail.to_owned(),
            (head, "") => head.to_owned(),
            (head, tail) => format!("{head}.{tail}"),
        };
        ProtoPath(joined)
    }
}

impl Display for ProtoPath {
    /// Render the raw dotted path.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for ProtoPath {
    /// Wrap an existing dotted path without validation.
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
272 
/// RustModulePath is a `::`-separated path to a rust module or item.
///
/// Example: `helloworld::greeter::HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct RustModulePath(String);

impl RustModulePath {
    /// Append `path` using a `::` separator.
    ///
    /// If either side is empty, the other side is returned unchanged (no separator).
    fn join(&self, path: &str) -> RustModulePath {
        if self.0.is_empty() {
            RustModulePath(path.to_owned())
        } else if path.is_empty() {
            self.clone()
        } else {
            let mut joined = String::with_capacity(self.0.len() + 2 + path.len());
            joined.push_str(&self.0);
            joined.push_str("::");
            joined.push_str(path);
            RustModulePath(joined)
        }
    }
}

impl Display for RustModulePath {
    /// Render the raw `::`-separated path.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for RustModulePath {
    /// Wrap an existing module path without validation.
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
304 
305 /// Compute the `--extern_path` flags for a list of proto files. This is
306 /// expected to convert proto files into a BTreeMap of
307 /// `example.prost.helloworld`: `crate_name::example::prost::helloworld`.
get_extern_paths( descriptor_set: &FileDescriptorSet, crate_name: &str, ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String>308 fn get_extern_paths(
309     descriptor_set: &FileDescriptorSet,
310     crate_name: &str,
311 ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String> {
312     let mut extern_paths = BTreeMap::new();
313     let rust_path = RustModulePath(crate_name.to_string());
314 
315     for file in descriptor_set.file.iter() {
316         descriptor_set_file_to_extern_paths(&mut extern_paths, &rust_path, file);
317     }
318 
319     Ok(extern_paths)
320 }
321 
322 /// Add the extern_path pairs for a file descriptor type.
descriptor_set_file_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, rust_path: &RustModulePath, file: &FileDescriptorProto, )323 fn descriptor_set_file_to_extern_paths(
324     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
325     rust_path: &RustModulePath,
326     file: &FileDescriptorProto,
327 ) {
328     let package = file.package.clone().unwrap_or_default();
329     let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
330     let proto_path = ProtoPath(package);
331 
332     for message_type in file.message_type.iter() {
333         message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, message_type);
334     }
335 
336     for enum_type in file.enum_type.iter() {
337         enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
338     }
339 }
340 
341 /// Add the extern_path pairs for a message descriptor type.
message_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, message_type: &DescriptorProto, )342 fn message_type_to_extern_paths(
343     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
344     proto_path: &ProtoPath,
345     rust_path: &RustModulePath,
346     message_type: &DescriptorProto,
347 ) {
348     let message_type_name = message_type
349         .name
350         .as_ref()
351         .expect("Failed to get message type name");
352 
353     extern_paths.insert(
354         proto_path.join(message_type_name),
355         rust_path.join(&message_type_name.to_upper_camel_case()),
356     );
357 
358     let name_lower = message_type_name.to_lowercase();
359     let proto_path = proto_path.join(&name_lower);
360     let rust_path = rust_path.join(&name_lower);
361 
362     for nested_type in message_type.nested_type.iter() {
363         message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, nested_type)
364     }
365 
366     for enum_type in message_type.enum_type.iter() {
367         enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
368     }
369 
370     for oneof_type in message_type.oneof_decl.iter() {
371         oneof_type_to_extern_paths(extern_paths, &proto_path, &rust_path, oneof_type);
372     }
373 }
374 
375 /// Add the extern_path pairs for an enum type.
enum_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, enum_type: &EnumDescriptorProto, )376 fn enum_type_to_extern_paths(
377     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
378     proto_path: &ProtoPath,
379     rust_path: &RustModulePath,
380     enum_type: &EnumDescriptorProto,
381 ) {
382     let enum_type_name = enum_type
383         .name
384         .as_ref()
385         .expect("Failed to get enum type name");
386     extern_paths.insert(
387         proto_path.join(enum_type_name),
388         rust_path.join(enum_type_name),
389     );
390 }
391 
oneof_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, oneof_type: &OneofDescriptorProto, )392 fn oneof_type_to_extern_paths(
393     extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
394     proto_path: &ProtoPath,
395     rust_path: &RustModulePath,
396     oneof_type: &OneofDescriptorProto,
397 ) {
398     let oneof_type_name = oneof_type
399         .name
400         .as_ref()
401         .expect("Failed to get oneof type name");
402     extern_paths.insert(
403         proto_path.join(oneof_type_name),
404         rust_path.join(oneof_type_name),
405     );
406 }
407 
/// The parsed command-line arguments, as produced by `Args::parse`.
struct Args {
    /// The path to the protoc binary (`--protoc=`).
    protoc: PathBuf,

    /// The path to the output directory (`--prost_out=`). `main` generates into a
    /// label-specific subdirectory of it.
    out_dir: PathBuf,

    /// The name of the crate (the `<crate>` part of `--package_info_output=<crate>=<file>`).
    /// It is the root of every Rust path written to the package info file.
    crate_name: String,

    /// The bazel label (`--label=`); used to derive the per-target output subdirectory.
    label: String,

    /// The path to the package info file (the `<file>` part of `--package_info_output`),
    /// which receives this crate's `.proto.path=::rust::path` mappings.
    package_info_file: PathBuf,

    /// The proto files to compile (positional arguments not starting with `-`).
    proto_files: Vec<PathBuf>,

    /// The include directories (`-I<dir>` arguments), forwarded to protoc.
    includes: Vec<String>,

    /// Path to a descriptor set (`--descriptor_set=`). `main` decodes it to compute
    /// this crate's extern paths and which outputs to expect.
    /// NOTE(review): previously documented as "dependency descriptor sets"; its use
    /// suggests it describes this target's own protos — confirm.
    descriptor_set: PathBuf,

    /// The path to the generated lib.rs file (`--out_librs=`).
    out_librs: PathBuf,

    /// The proto include paths (`--proto_path=` values), forwarded to protoc.
    proto_paths: Vec<String>,

    /// The path to the rustfmt binary (`--rustfmt=`); when set, lib.rs is formatted.
    rustfmt: Option<PathBuf>,

    /// Whether to generate tonic code (`--is_tonic`).
    is_tonic: bool,

    /// Extra arguments to pass to protoc. This holds unrecognised flags plus the
    /// `--prost_opt`/`--tonic_opt` extern_path options built from `--deps_info`.
    extra_args: Vec<String>,
}
449 
450 impl Args {
451     /// Parse the command-line arguments.
parse() -> Result<Args, String>452     fn parse() -> Result<Args, String> {
453         let mut protoc: Option<PathBuf> = None;
454         let mut out_dir: Option<PathBuf> = None;
455         let mut crate_name: Option<String> = None;
456         let mut package_info_file: Option<PathBuf> = None;
457         let mut proto_files: Vec<PathBuf> = Vec::new();
458         let mut includes = Vec::new();
459         let mut descriptor_set = None;
460         let mut out_librs: Option<PathBuf> = None;
461         let mut rustfmt: Option<PathBuf> = None;
462         let mut proto_paths = Vec::new();
463         let mut label: Option<String> = None;
464         let mut tonic_or_prost_opts = Vec::new();
465         let mut is_tonic = false;
466 
467         let mut extra_args = Vec::new();
468 
469         let mut handle_arg = |arg: String| {
470             if !arg.starts_with('-') {
471                 proto_files.push(PathBuf::from(arg));
472                 return;
473             }
474 
475             if arg.starts_with("-I") {
476                 includes.push(
477                     arg.strip_prefix("-I")
478                         .expect("Failed to strip -I")
479                         .to_string(),
480                 );
481                 return;
482             }
483 
484             if arg == "--is_tonic" {
485                 is_tonic = true;
486                 return;
487             }
488 
489             if !arg.contains('=') {
490                 extra_args.push(arg);
491                 return;
492             }
493 
494             let parts = arg.split_once('=').expect("Failed to split argument on =");
495             match parts {
496                 ("--protoc", value) => {
497                     protoc = Some(PathBuf::from(value));
498                 }
499                 ("--prost_out", value) => {
500                     out_dir = Some(PathBuf::from(value));
501                 }
502                 ("--package_info_output", value) => {
503                     let (key, value) = value
504                         .split_once('=')
505                         .map(|(a, b)| (a.to_string(), PathBuf::from(b)))
506                         .expect("Failed to parse package info output");
507                     crate_name = Some(key);
508                     package_info_file = Some(value);
509                 }
510                 ("--deps_info", value) => {
511                     for line in fs::read_to_string(value)
512                         .expect("Failed to read file")
513                         .lines()
514                     {
515                         let path = PathBuf::from(line.trim());
516                         for flag in fs::read_to_string(path)
517                             .expect("Failed to read file")
518                             .lines()
519                         {
520                             tonic_or_prost_opts.push(format!("extern_path={}", flag.trim()));
521                         }
522                     }
523                 }
524                 ("--descriptor_set", value) => {
525                     descriptor_set = Some(PathBuf::from(value));
526                 }
527                 ("--out_librs", value) => {
528                     out_librs = Some(PathBuf::from(value));
529                 }
530                 ("--rustfmt", value) => {
531                     rustfmt = Some(PathBuf::from(value));
532                 }
533                 ("--proto_path", value) => {
534                     proto_paths.push(value.to_string());
535                 }
536                 ("--label", value) => {
537                     label = Some(value.to_string());
538                 }
539                 (arg, value) => {
540                     extra_args.push(format!("{}={}", arg, value));
541                 }
542             }
543         };
544 
545         // Iterate over the given command line arguments parsing out arguments
546         // for the process runner and arguments for protoc and potentially spawn
547         // additional arguments needed by prost.
548         for arg in env::args().skip(1) {
549             if let Some(path) = arg.strip_prefix('@') {
550                 // handle argfile
551                 let file = std::fs::File::open(path)
552                     .map_err(|_| format!("could not open argfile: {}", arg))?;
553                 for line in std::io::BufReader::new(file).lines() {
554                     handle_arg(line.map_err(|_| format!("could not read argfile: {}", arg))?);
555                 }
556             } else {
557                 handle_arg(arg);
558             }
559         }
560 
561         for tonic_or_prost_opt in tonic_or_prost_opts {
562             extra_args.push(format!("--prost_opt={}", tonic_or_prost_opt));
563             if is_tonic {
564                 extra_args.push(format!("--tonic_opt={}", tonic_or_prost_opt));
565             }
566         }
567 
568         if protoc.is_none() {
569             return Err(
570                 "No `--protoc` value was found. Unable to parse path to proto compiler."
571                     .to_string(),
572             );
573         }
574         if out_dir.is_none() {
575             return Err(
576                 "No `--prost_out` value was found. Unable to parse output directory.".to_string(),
577             );
578         }
579         if crate_name.is_none() {
580             return Err(
581                 "No `--package_info_output` value was found. Unable to parse target crate name."
582                     .to_string(),
583             );
584         }
585         if package_info_file.is_none() {
586             return Err("No `--package_info_output` value was found. Unable to parse package info output file.".to_string());
587         }
588         if out_librs.is_none() {
589             return Err("No `--out_librs` value was found. Unable to parse the output location for all combined prost outputs.".to_string());
590         }
591         if descriptor_set.is_none() {
592             return Err(
593                 "No `--descriptor_set` value was found. Unable to parse descriptor set path."
594                     .to_string(),
595             );
596         }
597         if label.is_none() {
598             return Err(
599                 "No `--label` value was found. Unable to parse the label of the target crate."
600                     .to_string(),
601             );
602         }
603 
604         Ok(Args {
605             protoc: protoc.unwrap(),
606             out_dir: out_dir.unwrap(),
607             crate_name: crate_name.unwrap(),
608             package_info_file: package_info_file.unwrap(),
609             proto_files,
610             includes,
611             descriptor_set: descriptor_set.unwrap(),
612             out_librs: out_librs.unwrap(),
613             rustfmt,
614             proto_paths,
615             is_tonic,
616             label: label.unwrap(),
617             extra_args,
618         })
619     }
620 }
621 
/// Get the output directory with the label suffixed.
///
/// The label is flattened into a single path component: `@` is dropped and
/// `//`, `/` and `:` become `_`. For example, `@repo//pkg:target` yields
/// `<out_dir>/prost-build-repo_pkg_target`.
///
/// The result is built with [`Path::join`] rather than formatting `display()`
/// into a string. This keeps non-UTF-8 `out_dir`s intact, and an empty
/// `out_dir` yields a relative path instead of one rooted at `/`.
fn get_output_dir(out_dir: &Path, label: &str) -> PathBuf {
    let label_as_path = label
        .replace('@', "")
        .replace("//", "_")
        .replace(['/', ':'], "_");
    out_dir.join(format!("prost-build-{}", label_as_path))
}
634 
635 /// Get the output directory with the label suffixed, and create it if it doesn't exist.
636 ///
637 /// This will remove the directory first if it already exists.
get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf638 fn get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf {
639     let out_dir = get_output_dir(out_dir, label);
640     if out_dir.exists() {
641         fs::remove_dir_all(&out_dir).expect("Failed to remove old output directory");
642     }
643     fs::create_dir_all(&out_dir).expect("Failed to create output directory");
644     out_dir
645 }
646 
647 /// Parse the descriptor set file into a `FileDescriptorSet`.
parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet648 fn parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet {
649     let descriptor_set_bytes =
650         fs::read(descriptor_set_path).expect("Failed to read descriptor set");
651     let descriptor_set = FileDescriptorSet::decode(descriptor_set_bytes.as_slice())
652         .expect("Failed to decode descriptor set");
653 
654     descriptor_set
655 }
656 
657 /// Get the package name from the descriptor set.
get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String>658 fn get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String> {
659     let mut package_name = None;
660 
661     for file in &descriptor_set.file {
662         if let Some(package) = &file.package {
663             package_name = Some(package.clone());
664             break;
665         }
666     }
667 
668     package_name
669 }
670 
671 /// Whether the proto file should expect to generate a .rs file.
672 ///
673 /// If the proto file contains any messages, enums, or services, then it should generate a rust file.
674 /// If the proto file only contains extensions, then it will not generate any rust files.
expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool675 fn expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool {
676     let mut expect_rs = false;
677 
678     for file in descriptor_set.file.iter() {
679         let has_messages = !file.message_type.is_empty();
680         let has_enums = !file.enum_type.is_empty();
681         let has_services = !file.service.is_empty();
682         let has_extensions = !file.extension.is_empty();
683 
684         let has_definition = has_messages || has_enums || has_services;
685 
686         if has_definition {
687             return true;
688         } else if !has_definition && !has_extensions {
689             expect_rs = true;
690         }
691     }
692 
693     expect_rs
694 }
695 
696 /// Whether the proto file should expect to generate service definitions.
has_services(descriptor_set: &FileDescriptorSet) -> bool697 fn has_services(descriptor_set: &FileDescriptorSet) -> bool {
698     descriptor_set
699         .file
700         .iter()
701         .any(|file| !file.service.is_empty())
702 }
703 
main()704 fn main() {
705     // Always enable backtraces for the protoc wrapper.
706     env::set_var("RUST_BACKTRACE", "1");
707 
708     let Args {
709         protoc,
710         out_dir,
711         crate_name,
712         label,
713         package_info_file,
714         proto_files,
715         includes,
716         descriptor_set,
717         out_librs,
718         rustfmt,
719         proto_paths,
720         is_tonic,
721         extra_args,
722     } = Args::parse().expect("Failed to parse args");
723 
724     let out_dir = get_and_create_output_dir(&out_dir, &label);
725 
726     let descriptor_set = parse_descriptor_set_file(&descriptor_set);
727     let package_name = get_package_name(&descriptor_set).unwrap_or_default();
728     let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
729     let has_services = has_services(&descriptor_set);
730 
731     if has_services && !is_tonic {
732         eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
733     }
734 
735     let mut cmd = process::Command::new(protoc);
736     cmd.arg(format!("--prost_out={}", out_dir.display()));
737     if is_tonic {
738         cmd.arg(format!("--tonic_out={}", out_dir.display()));
739     }
740     cmd.args(extra_args);
741     cmd.args(
742         proto_paths
743             .iter()
744             .map(|proto_path| format!("--proto_path={}", proto_path)),
745     );
746     cmd.args(includes.iter().map(|include| format!("-I{}", include)));
747     cmd.args(&proto_files);
748 
749     let status = cmd.status().expect("Failed to spawn protoc process");
750     if !status.success() {
751         panic!(
752             "protoc failed with status: {}",
753             status.code().expect("failed to get exit code")
754         );
755     }
756 
757     // Not all proto files will consistently produce `.rs` or `.tonic.rs` files. This is
758     // caused by the proto file being transpiled not having an RPC service or other protos
759     // defined (a natural and expected situation). To guarantee consistent outputs, all
760     // `.rs` files are either renamed to `.tonic.rs` if there is no `.tonic.rs` or prepended
761     // to the existing `.tonic.rs`.
762     if is_tonic {
763         let tonic_files: BTreeSet<PathBuf> = find_generated_rust_files(&out_dir);
764 
765         for tonic_file in tonic_files.iter() {
766             let tonic_path_str = tonic_file.to_str().expect("Failed to convert to str");
767             let filename = tonic_file
768                 .file_name()
769                 .expect("Failed to get file name")
770                 .to_str()
771                 .expect("Failed to convert to str");
772 
773             let is_tonic_file = filename.ends_with(".tonic.rs");
774 
775             if is_tonic_file {
776                 let rs_file_str = format!(
777                     "{}.rs",
778                     tonic_path_str
779                         .strip_suffix(".tonic.rs")
780                         .expect("Failed to strip suffix.")
781                 );
782                 let rs_file = PathBuf::from(&rs_file_str);
783 
784                 if rs_file.exists() {
785                     let rs_content = fs::read_to_string(&rs_file).expect("Failed to read file.");
786                     let tonic_content =
787                         fs::read_to_string(tonic_file).expect("Failed to read file.");
788                     fs::write(tonic_file, format!("{}\n{}", rs_content, tonic_content))
789                         .expect("Failed to write file.");
790                     fs::remove_file(&rs_file).unwrap_or_else(|err| {
791                         panic!("Failed to remove file: {err:?}: {rs_file:?}")
792                     });
793                 }
794             } else {
795                 let real_tonic_file = PathBuf::from(format!(
796                     "{}.tonic.rs",
797                     tonic_path_str
798                         .strip_suffix(".rs")
799                         .expect("Failed to strip suffix.")
800                 ));
801                 if real_tonic_file.exists() {
802                     continue;
803                 }
804                 fs::rename(tonic_file, &real_tonic_file).unwrap_or_else(|err| {
805                     panic!("Failed to rename file: {err:?}: {tonic_file:?} -> {real_tonic_file:?}");
806                 });
807             }
808         }
809     }
810 
811     // Locate all prost-generated outputs.
812     let mut rust_files = find_generated_rust_files(&out_dir);
813     if rust_files.is_empty() {
814         if expect_rs {
815             panic!("No .rs files were generated by prost.");
816         } else {
817             let file_stem = if package_name.is_empty() {
818                 "_"
819             } else {
820                 &package_name
821             };
822             let file_stem = format!("{}{}", file_stem, if is_tonic { ".tonic" } else { "" });
823             let empty_rs_file = out_dir.join(format!("{}.rs", file_stem));
824             fs::write(&empty_rs_file, "").expect("Failed to write file.");
825             rust_files.insert(empty_rs_file);
826         }
827     }
828 
829     let extern_paths = get_extern_paths(&descriptor_set, &crate_name)
830         .expect("Failed to compute proto package info");
831 
832     // Write outputs
833     fs::write(&out_librs, generate_lib_rs(&rust_files, is_tonic)).expect("Failed to write file.");
834     fs::write(
835         package_info_file,
836         extern_paths
837             .into_iter()
838             .map(|(proto_path, rust_path)| format!(".{}=::{}", proto_path, rust_path))
839             .collect::<Vec<_>>()
840             .join("\n"),
841     )
842     .expect("Failed to write file.");
843 
844     // Finally run rustfmt on the output lib.rs file
845     if let Some(rustfmt) = rustfmt {
846         let fmt_status = process::Command::new(rustfmt)
847             .arg("--edition")
848             .arg("2021")
849             .arg("--quiet")
850             .arg(&out_librs)
851             .status()
852             .expect("Failed to spawn rustfmt process");
853         if !fmt_status.success() {
854             panic!(
855                 "rustfmt failed with exit code: {}",
856                 fmt_status.code().expect("Failed to get exit code")
857             );
858         }
859     }
860 }
861 
/// Words that may not appear as plain Rust identifiers: the strict keywords
/// together with the keywords reserved for future use, in alphabetical order
/// (one line per initial letter).
///
/// NOTE(review): `crate`, `self`, `Self` and `super` appear here, but rustc does
/// not accept them as raw identifiers (`r#self` is an error) — confirm callers
/// that escape with `r#` never receive these.
#[rustfmt::skip]
const RUST_KEYWORDS: [&str; 51] = [
    "abstract", "as", "async", "await",
    "become", "box", "break",
    "const", "continue", "crate",
    "do", "dyn",
    "else", "enum", "extern",
    "false", "final", "fn", "for",
    "if", "impl", "in",
    "let", "loop",
    "macro", "match", "mod", "move", "mut",
    "override",
    "priv", "pub",
    "ref", "return",
    "self", "Self", "static", "struct", "super",
    "trait", "true", "try", "type", "typeof",
    "unsafe", "unsized", "use",
    "virtual",
    "where", "while",
    "yield",
];
870 
871 /// Returns true if the given string is a Rust keyword.
is_keyword(s: &str) -> bool872 fn is_keyword(s: &str) -> bool {
873     RUST_KEYWORDS.contains(&s)
874 }
875 
876 /// Escapes a Rust keyword by prefixing it with `r#`.
escape_keyword(s: String) -> String877 fn escape_keyword(s: String) -> String {
878     if is_keyword(&s) {
879         return format!("r#{s}");
880     }
881     s
882 }
883 
/// Unit tests for the extern-path computation, descriptor inspection and
/// keyword-escaping helpers defined in this file.
#[cfg(test)]
mod test {

    use super::*;

    use prost_types::{FieldDescriptorProto, ServiceDescriptorProto};

    /// Identifiers that the keyword tests expect to be treated as ordinary names.
    const NON_KEYWORDS: [&str; 13] = [
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
        "plugh", "xyzzy", "thud",
    ];

    /// Wraps `file` in a single-file descriptor set, naming the file `foo.proto`.
    fn single_file_set(file: FileDescriptorProto) -> FileDescriptorSet {
        FileDescriptorSet {
            file: vec![FileDescriptorProto {
                name: Some("foo.proto".to_string()),
                ..file
            }],
        }
    }

    /// A `foo.proto` set that declares nothing at all.
    fn empty_set() -> FileDescriptorSet {
        single_file_set(FileDescriptorProto::default())
    }

    /// A `foo.proto` set that declares only the message `Foo`.
    fn message_only_set() -> FileDescriptorSet {
        single_file_set(FileDescriptorProto {
            message_type: vec![DescriptorProto {
                name: Some("Foo".to_string()),
                ..DescriptorProto::default()
            }],
            ..FileDescriptorProto::default()
        })
    }

    /// A `foo.proto` set that declares only the enum `Foo`.
    fn enum_only_set() -> FileDescriptorSet {
        single_file_set(FileDescriptorProto {
            enum_type: vec![EnumDescriptorProto {
                name: Some("Foo".to_string()),
                ..EnumDescriptorProto::default()
            }],
            ..FileDescriptorProto::default()
        })
    }

    /// A `foo.proto` set that declares only the service `Foo`.
    fn service_only_set() -> FileDescriptorSet {
        single_file_set(FileDescriptorProto {
            service: vec![ServiceDescriptorProto {
                name: Some("Foo".to_string()),
                ..ServiceDescriptorProto::default()
            }],
            ..FileDescriptorProto::default()
        })
    }

    /// A `foo.proto` set that declares only the extension field `Foo`.
    fn extension_only_set() -> FileDescriptorSet {
        single_file_set(FileDescriptorProto {
            extension: vec![FieldDescriptorProto {
                name: Some("Foo".to_string()),
                ..FieldDescriptorProto::default()
            }],
            ..FileDescriptorProto::default()
        })
    }

    /// Asserts that `extern_paths` holds `expected.len()` entries and that every
    /// `(proto path, Rust path)` pair in `expected` is among them.
    fn assert_extern_paths(
        extern_paths: &BTreeMap<ProtoPath, RustModulePath>,
        expected: &[(&'static str, &'static str)],
    ) {
        assert_eq!(extern_paths.len(), expected.len());
        for &(proto_path, rust_path) in expected {
            assert_eq!(
                extern_paths.get(&ProtoPath::from(proto_path)),
                Some(&RustModulePath::from(rust_path)),
                "unexpected Rust path for proto path `{proto_path}`",
            );
        }
    }

    #[test]
    fn oneof_type_to_extern_paths_test() {
        let oneof_descriptor = OneofDescriptorProto {
            name: Some("Foo".to_string()),
            ..OneofDescriptorProto::default()
        };

        // (proto base path, Rust base path, the single expected mapping)
        for (proto_base, rust_base, expected) in [
            ("bar", "bar", ("bar.Foo", "bar::Foo")),
            ("bar.baz", "bar::baz", ("bar.baz.Foo", "bar::baz::Foo")),
        ] {
            let mut extern_paths = BTreeMap::new();
            oneof_type_to_extern_paths(
                &mut extern_paths,
                &ProtoPath::from(proto_base),
                &RustModulePath::from(rust_base),
                &oneof_descriptor,
            );
            assert_extern_paths(&extern_paths, &[expected]);
        }
    }

    #[test]
    fn enum_type_to_extern_paths_test() {
        let enum_descriptor = EnumDescriptorProto {
            name: Some("Foo".to_string()),
            ..EnumDescriptorProto::default()
        };

        // (proto base path, Rust base path, the single expected mapping)
        for (proto_base, rust_base, expected) in [
            ("bar", "bar", ("bar.Foo", "bar::Foo")),
            ("bar.baz", "bar::baz", ("bar.baz.Foo", "bar::baz::Foo")),
        ] {
            let mut extern_paths = BTreeMap::new();
            enum_type_to_extern_paths(
                &mut extern_paths,
                &ProtoPath::from(proto_base),
                &RustModulePath::from(rust_base),
                &enum_descriptor,
            );
            assert_extern_paths(&extern_paths, &[expected]);
        }
    }

    #[test]
    fn message_type_to_extern_paths_test() {
        let message_descriptor = DescriptorProto {
            name: Some("Foo".to_string()),
            nested_type: vec![
                DescriptorProto {
                    name: Some("Bar".to_string()),
                    ..DescriptorProto::default()
                },
                DescriptorProto {
                    name: Some("Nested".to_string()),
                    nested_type: vec![DescriptorProto {
                        name: Some("Baz".to_string()),
                        enum_type: vec![EnumDescriptorProto {
                            name: Some("Chuck".to_string()),
                            ..EnumDescriptorProto::default()
                        }],
                        ..DescriptorProto::default()
                    }],
                    ..DescriptorProto::default()
                },
            ],
            enum_type: vec![EnumDescriptorProto {
                name: Some("Qux".to_string()),
                ..EnumDescriptorProto::default()
            }],
            ..DescriptorProto::default()
        };

        // Every one of the six generated entries is checked, including the
        // enums nested inside messages (`Qux` and `Chuck`), which previously
        // were only counted and never verified.
        for (proto_base, rust_base, expected) in [
            (
                "bar",
                "bar",
                [
                    ("bar.Foo", "bar::Foo"),
                    ("bar.foo.Bar", "bar::foo::Bar"),
                    ("bar.foo.Nested", "bar::foo::Nested"),
                    ("bar.foo.nested.Baz", "bar::foo::nested::Baz"),
                    ("bar.foo.nested.baz.Chuck", "bar::foo::nested::baz::Chuck"),
                    ("bar.foo.Qux", "bar::foo::Qux"),
                ],
            ),
            (
                "bar.bob",
                "bar::bob",
                [
                    ("bar.bob.Foo", "bar::bob::Foo"),
                    ("bar.bob.foo.Bar", "bar::bob::foo::Bar"),
                    ("bar.bob.foo.Nested", "bar::bob::foo::Nested"),
                    ("bar.bob.foo.nested.Baz", "bar::bob::foo::nested::Baz"),
                    (
                        "bar.bob.foo.nested.baz.Chuck",
                        "bar::bob::foo::nested::baz::Chuck",
                    ),
                    ("bar.bob.foo.Qux", "bar::bob::foo::Qux"),
                ],
            ),
        ] {
            let mut extern_paths = BTreeMap::new();
            message_type_to_extern_paths(
                &mut extern_paths,
                &ProtoPath::from(proto_base),
                &RustModulePath::from(rust_base),
                &message_descriptor,
            );
            assert_extern_paths(&extern_paths, &expected);
        }
    }

    #[test]
    fn proto_path_test() {
        // (path, segment to join, expected joined path); the path must also
        // display exactly as it was constructed.
        for (path, segment, joined) in [
            ("", "foo", "foo"),
            ("foo", "", "foo"),
            ("foo", "bar", "foo.bar"),
            ("foo.bar", "baz", "foo.bar.baz"),
            ("Foo.baR", "baz", "Foo.baR.baz"),
        ] {
            let proto_path = ProtoPath::from(path);
            assert_eq!(proto_path.to_string(), path);
            assert_eq!(proto_path.join(segment), ProtoPath::from(joined));
        }
    }

    #[test]
    fn rust_module_path_test() {
        // (path, segment to join, expected joined path); the path must also
        // display exactly as it was constructed.
        for (path, segment, joined) in [
            ("", "foo", "foo"),
            ("foo", "", "foo"),
            ("foo", "bar", "foo::bar"),
            ("foo::bar", "baz", "foo::bar::baz"),
        ] {
            let rust_module_path = RustModulePath::from(path);
            assert_eq!(rust_module_path.to_string(), path);
            assert_eq!(
                rust_module_path.join(segment),
                RustModulePath::from(joined)
            );
        }
    }

    #[test]
    fn expect_fs_file_to_be_generated_test() {
        // Empty descriptor set should create a file.
        assert!(expect_fs_file_to_be_generated(&empty_set()));
        // Descriptor set with only message should create a file.
        assert!(expect_fs_file_to_be_generated(&message_only_set()));
        // Descriptor set with only enum should create a file.
        assert!(expect_fs_file_to_be_generated(&enum_only_set()));
        // Descriptor set with only service should create a file.
        assert!(expect_fs_file_to_be_generated(&service_only_set()));
        // Descriptor set with only extensions should not create a file.
        assert!(!expect_fs_file_to_be_generated(&extension_only_set()));
    }

    #[test]
    fn has_services_test() {
        // Empty file should not have services.
        assert!(!has_services(&empty_set()));
        // File with only message should not have services.
        assert!(!has_services(&message_only_set()));
        // File with services should have services.
        assert!(has_services(&service_only_set()));
    }

    #[test]
    fn get_package_name_test() {
        let descriptor_set = single_file_set(FileDescriptorProto {
            package: Some("foo".to_string()),
            ..FileDescriptorProto::default()
        });

        assert_eq!(get_package_name(&descriptor_set), Some("foo".to_string()));
    }

    #[test]
    fn is_keyword_test() {
        for non_keyword in NON_KEYWORDS {
            assert!(!is_keyword(non_keyword));
        }

        // Matching is exact and case-sensitive.
        assert!(!is_keyword("Fn"));

        for keyword in RUST_KEYWORDS {
            assert!(is_keyword(keyword));
        }
    }

    #[test]
    fn escape_keyword_test() {
        for non_keyword in NON_KEYWORDS {
            assert_eq!(escape_keyword(non_keyword.to_string()), non_keyword);
        }

        for keyword in RUST_KEYWORDS {
            assert_eq!(
                escape_keyword(keyword.to_string()),
                format!("r#{keyword}")
            );
        }
    }
}
1275