1 //! A process wrapper for running a Protobuf compiler configured for Prost or Tonic output in a Bazel rule.
2
3 use std::collections::BTreeMap;
4 use std::collections::BTreeSet;
5 use std::fmt::{Display, Formatter, Write};
6 use std::fs;
7 use std::io::BufRead;
8 use std::path::Path;
9 use std::path::PathBuf;
10 use std::process;
11 use std::{env, fmt};
12
13 use heck::{ToSnakeCase, ToUpperCamelCase};
14 use prost::Message;
15 use prost_types::{
16 DescriptorProto, EnumDescriptorProto, FileDescriptorProto, FileDescriptorSet,
17 OneofDescriptorProto,
18 };
19
/// Locate prost outputs in the protoc output directory.
///
/// Walks `out_dir` recursively and returns every file with an `.rs`
/// extension. A file named exactly `_` with no extension (the name used when
/// the package name is empty) is renamed in place to `_.rs` and returned under
/// its new name. All other files are ignored.
///
/// # Panics
///
/// Panics if a directory or directory entry cannot be read, or if renaming a
/// `_` file fails.
fn find_generated_rust_files(out_dir: &Path) -> BTreeSet<PathBuf> {
    let mut found: BTreeSet<PathBuf> = BTreeSet::new();

    let entries = fs::read_dir(out_dir).expect("Failed to read directory");
    for dir_entry in entries {
        let candidate = dir_entry.expect("Failed to read entry").path();

        if candidate.is_dir() {
            found.extend(find_generated_rust_files(&candidate));
            continue;
        }

        match candidate.extension() {
            Some(extension) => {
                if extension == "rs" {
                    found.insert(candidate);
                }
            }
            None if matches!(candidate.file_name(), Some(file_name) if file_name == "_") => {
                let renamed = candidate
                    .parent()
                    .expect("Failed to get parent")
                    .join("_.rs");
                if let Err(err) = fs::rename(&candidate, &renamed) {
                    panic!(
                        "Failed to rename file: {:?}: {:?} -> {:?}",
                        err, candidate, renamed
                    );
                }
                found.insert(renamed);
            }
            None => {}
        }
    }

    found
}
48
snake_cased_package_name(package: &str) -> String49 fn snake_cased_package_name(package: &str) -> String {
50 if package == "_" {
51 return package.to_owned();
52 }
53
54 package
55 .split('.')
56 .map(|s| s.to_snake_case())
57 .collect::<Vec<_>>()
58 .join(".")
59 }
60
/// One Rust module to be emitted into the generated lib.rs.
#[derive(Default, Debug)]
struct Module {
    /// Identifier used in the `pub mod` declaration. The special name `_`
    /// means the contents are written without a wrapping module.
    name: String,

    /// Generated Rust source placed directly inside the module.
    contents: String,

    /// Final path segment of every direct child module.
    submodules: BTreeSet<String>,
}
73
74 /// Generate a lib.rs file with all prost/tonic outputs embeeded in modules which
75 /// mirror the proto packages. For the example proto file we would expect to see
76 /// the Rust output that follows it.
77 ///
78 /// ```proto
79 /// syntax = "proto3";
80 /// package examples.prost.helloworld;
81 ///
82 /// message HelloRequest {
83 /// // Request message contains the name to be greeted
84 /// string name = 1;
85 /// }
86 //
87 /// message HelloReply {
88 /// // Reply contains the greeting message
89 /// string message = 1;
90 /// }
91 /// ```
92 ///
93 /// This is expected to render out to something like the following. Note that
94 /// formatting is not applied so indentation may be missing in the actual output.
95 ///
96 /// ```ignore
97 /// pub mod examples {
98 /// pub mod prost {
99 /// pub mod helloworld {
100 /// // @generated
101 /// #[allow(clippy::derive_partial_eq_without_eq)]
102 /// #[derive(Clone, PartialEq, ::prost::Message)]
103 /// pub struct HelloRequest {
104 /// /// Request message contains the name to be greeted
105 /// #[prost(string, tag = "1")]
106 /// pub name: ::prost::alloc::string::String,
107 /// }
108 /// #[allow(clippy::derive_partial_eq_without_eq)]
109 /// #[derive(Clone, PartialEq, ::prost::Message)]
110 /// pub struct HelloReply {
111 /// /// Reply contains the greeting message
112 /// #[prost(string, tag = "1")]
113 /// pub message: ::prost::alloc::string::String,
114 /// }
115 /// // @protoc_insertion_point(module)
116 /// }
117 /// }
118 /// }
119 /// ```
generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String120 fn generate_lib_rs(prost_outputs: &BTreeSet<PathBuf>, is_tonic: bool) -> String {
121 let mut module_info = BTreeMap::new();
122
123 for path in prost_outputs.iter() {
124 let mut package = path
125 .file_stem()
126 .expect("Failed to get file stem")
127 .to_str()
128 .expect("Failed to convert to str")
129 .to_string();
130
131 if is_tonic {
132 package = package
133 .strip_suffix(".tonic")
134 .expect("Failed to strip suffix")
135 .to_string()
136 };
137
138 if package.is_empty() {
139 continue;
140 }
141
142 let name = if package == "_" {
143 package.clone()
144 } else if package.contains('.') {
145 package
146 .rsplit_once('.')
147 .expect("Failed to split on '.'")
148 .1
149 .to_snake_case()
150 .to_string()
151 } else {
152 package.to_snake_case()
153 };
154
155 // Avoid a stack overflow by skipping a known bad package name
156 let module_name = snake_cased_package_name(&package);
157
158 module_info.insert(
159 module_name.clone(),
160 Module {
161 name,
162 contents: fs::read_to_string(path).expect("Failed to read file"),
163 submodules: BTreeSet::new(),
164 },
165 );
166
167 let module_parts = module_name.split('.').collect::<Vec<_>>();
168 for parent_module_index in 0..module_parts.len() {
169 let child_module_index = parent_module_index + 1;
170 if child_module_index >= module_parts.len() {
171 break;
172 }
173 let full_parent_module_name = module_parts[0..parent_module_index + 1].join(".");
174 let parent_module_name = module_parts[parent_module_index];
175 let child_module_name = module_parts[child_module_index];
176
177 module_info
178 .entry(full_parent_module_name.clone())
179 .and_modify(|parent_module| {
180 parent_module
181 .submodules
182 .insert(child_module_name.to_string());
183 })
184 .or_insert(Module {
185 name: parent_module_name.to_string(),
186 contents: "".to_string(),
187 submodules: [child_module_name.to_string()].iter().cloned().collect(),
188 });
189 }
190 }
191
192 let mut content = "// @generated\n\n".to_string();
193 write_module(&mut content, &module_info, "", 0);
194 content
195 }
196
197 /// Write out a rust module and all of its submodules.
write_module( content: &mut String, module_info: &BTreeMap<String, Module>, module_name: &str, depth: usize, )198 fn write_module(
199 content: &mut String,
200 module_info: &BTreeMap<String, Module>,
201 module_name: &str,
202 depth: usize,
203 ) {
204 if module_name.is_empty() {
205 for submodule_name in module_info.keys() {
206 write_module(content, module_info, submodule_name, depth + 1);
207 }
208 return;
209 }
210 let module = module_info.get(module_name).expect("Failed to get module");
211 let indent = " ".repeat(depth);
212 let is_rust_module = module.name != "_";
213
214 if is_rust_module {
215 let rust_module_name = escape_keyword(module.name.clone());
216 content
217 .write_str(&format!("{}pub mod {} {{\n", indent, rust_module_name))
218 .expect("Failed to write string");
219 }
220
221 content
222 .write_str(&module.contents)
223 .expect("Failed to write string");
224
225 for submodule_name in module.submodules.iter() {
226 write_module(
227 content,
228 module_info,
229 [module_name, submodule_name].join(".").as_str(),
230 depth + 1,
231 );
232 }
233
234 if is_rust_module {
235 content
236 .write_str(&format!("{}}}\n", indent))
237 .expect("Failed to write string");
238 }
239 }
240
/// ProtoPath is a dotted path to a proto message, enum, or oneof.
///
/// Example: `helloworld.Greeter.HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct ProtoPath(String);

impl ProtoPath {
    /// Append `component` using `.` as the separator.
    ///
    /// An empty path yields `component` as-is, and an empty `component` leaves
    /// the path unchanged, so no stray `.` is produced.
    fn join(&self, component: &str) -> ProtoPath {
        match (self.0.is_empty(), component.is_empty()) {
            (true, _) => ProtoPath(component.to_owned()),
            (false, true) => self.clone(),
            (false, false) => ProtoPath([self.0.as_str(), component].join(".")),
        }
    }
}

impl Display for ProtoPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for ProtoPath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
272
/// RustModulePath is a `::`-separated path to a Rust module or item.
///
/// Example: `helloworld::greeter::HelloRequest`
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
struct RustModulePath(String);

impl RustModulePath {
    /// Append `path` using `::` as the separator.
    ///
    /// An empty module path yields `path` as-is, and an empty `path` leaves the
    /// module path unchanged, so no stray `::` is produced.
    fn join(&self, path: &str) -> RustModulePath {
        match (self.0.is_empty(), path.is_empty()) {
            (true, _) => RustModulePath(path.to_owned()),
            (false, true) => self.clone(),
            (false, false) => RustModulePath([self.0.as_str(), path].join("::")),
        }
    }
}

impl Display for RustModulePath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<&str> for RustModulePath {
    fn from(path: &str) -> Self {
        Self(path.to_owned())
    }
}
304
305 /// Compute the `--extern_path` flags for a list of proto files. This is
306 /// expected to convert proto files into a BTreeMap of
307 /// `example.prost.helloworld`: `crate_name::example::prost::helloworld`.
get_extern_paths( descriptor_set: &FileDescriptorSet, crate_name: &str, ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String>308 fn get_extern_paths(
309 descriptor_set: &FileDescriptorSet,
310 crate_name: &str,
311 ) -> Result<BTreeMap<ProtoPath, RustModulePath>, String> {
312 let mut extern_paths = BTreeMap::new();
313 let rust_path = RustModulePath(crate_name.to_string());
314
315 for file in descriptor_set.file.iter() {
316 descriptor_set_file_to_extern_paths(&mut extern_paths, &rust_path, file);
317 }
318
319 Ok(extern_paths)
320 }
321
322 /// Add the extern_path pairs for a file descriptor type.
descriptor_set_file_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, rust_path: &RustModulePath, file: &FileDescriptorProto, )323 fn descriptor_set_file_to_extern_paths(
324 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
325 rust_path: &RustModulePath,
326 file: &FileDescriptorProto,
327 ) {
328 let package = file.package.clone().unwrap_or_default();
329 let rust_path = rust_path.join(&snake_cased_package_name(&package).replace('.', "::"));
330 let proto_path = ProtoPath(package);
331
332 for message_type in file.message_type.iter() {
333 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, message_type);
334 }
335
336 for enum_type in file.enum_type.iter() {
337 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
338 }
339 }
340
341 /// Add the extern_path pairs for a message descriptor type.
message_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, message_type: &DescriptorProto, )342 fn message_type_to_extern_paths(
343 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
344 proto_path: &ProtoPath,
345 rust_path: &RustModulePath,
346 message_type: &DescriptorProto,
347 ) {
348 let message_type_name = message_type
349 .name
350 .as_ref()
351 .expect("Failed to get message type name");
352
353 extern_paths.insert(
354 proto_path.join(message_type_name),
355 rust_path.join(&message_type_name.to_upper_camel_case()),
356 );
357
358 let name_lower = message_type_name.to_lowercase();
359 let proto_path = proto_path.join(&name_lower);
360 let rust_path = rust_path.join(&name_lower);
361
362 for nested_type in message_type.nested_type.iter() {
363 message_type_to_extern_paths(extern_paths, &proto_path, &rust_path, nested_type)
364 }
365
366 for enum_type in message_type.enum_type.iter() {
367 enum_type_to_extern_paths(extern_paths, &proto_path, &rust_path, enum_type);
368 }
369
370 for oneof_type in message_type.oneof_decl.iter() {
371 oneof_type_to_extern_paths(extern_paths, &proto_path, &rust_path, oneof_type);
372 }
373 }
374
375 /// Add the extern_path pairs for an enum type.
enum_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, enum_type: &EnumDescriptorProto, )376 fn enum_type_to_extern_paths(
377 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
378 proto_path: &ProtoPath,
379 rust_path: &RustModulePath,
380 enum_type: &EnumDescriptorProto,
381 ) {
382 let enum_type_name = enum_type
383 .name
384 .as_ref()
385 .expect("Failed to get enum type name");
386 extern_paths.insert(
387 proto_path.join(enum_type_name),
388 rust_path.join(enum_type_name),
389 );
390 }
391
oneof_type_to_extern_paths( extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>, proto_path: &ProtoPath, rust_path: &RustModulePath, oneof_type: &OneofDescriptorProto, )392 fn oneof_type_to_extern_paths(
393 extern_paths: &mut BTreeMap<ProtoPath, RustModulePath>,
394 proto_path: &ProtoPath,
395 rust_path: &RustModulePath,
396 oneof_type: &OneofDescriptorProto,
397 ) {
398 let oneof_type_name = oneof_type
399 .name
400 .as_ref()
401 .expect("Failed to get oneof type name");
402 extern_paths.insert(
403 proto_path.join(oneof_type_name),
404 rust_path.join(oneof_type_name),
405 );
406 }
407
/// Everything the wrapper reads from its command line.
struct Args {
    // Tools.
    /// Location of the `protoc` executable (`--protoc`).
    protoc: PathBuf,

    /// Location of `rustfmt` (`--rustfmt`). When absent, lib.rs is not
    /// formatted.
    rustfmt: Option<PathBuf>,

    // Inputs.
    /// Proto sources to compile (every non-flag argument).
    proto_files: Vec<PathBuf>,

    /// Include directories, given as `-I<dir>`.
    includes: Vec<String>,

    /// Proto search paths, given as `--proto_path=<dir>`.
    proto_paths: Vec<String>,

    /// Path to a serialized `FileDescriptorSet` (`--descriptor_set`).
    descriptor_set: PathBuf,

    // Outputs.
    /// Directory protoc writes plugin output into (`--prost_out`).
    out_dir: PathBuf,

    /// Where the combined lib.rs is written (`--out_librs`).
    out_librs: PathBuf,

    /// Where the extern_path package info is written
    /// (from `--package_info_output=<crate>=<path>`).
    package_info_file: PathBuf,

    // Identity.
    /// Name of the crate being generated
    /// (from `--package_info_output=<crate>=<path>`).
    crate_name: String,

    /// Bazel label of the target (`--label`).
    label: String,

    // Behavior.
    /// Whether the tonic plugin should also run (`--is_tonic`).
    is_tonic: bool,

    /// Remaining arguments, forwarded to protoc untouched.
    extra_args: Vec<String>,
}
449
450 impl Args {
451 /// Parse the command-line arguments.
parse() -> Result<Args, String>452 fn parse() -> Result<Args, String> {
453 let mut protoc: Option<PathBuf> = None;
454 let mut out_dir: Option<PathBuf> = None;
455 let mut crate_name: Option<String> = None;
456 let mut package_info_file: Option<PathBuf> = None;
457 let mut proto_files: Vec<PathBuf> = Vec::new();
458 let mut includes = Vec::new();
459 let mut descriptor_set = None;
460 let mut out_librs: Option<PathBuf> = None;
461 let mut rustfmt: Option<PathBuf> = None;
462 let mut proto_paths = Vec::new();
463 let mut label: Option<String> = None;
464 let mut tonic_or_prost_opts = Vec::new();
465 let mut is_tonic = false;
466
467 let mut extra_args = Vec::new();
468
469 let mut handle_arg = |arg: String| {
470 if !arg.starts_with('-') {
471 proto_files.push(PathBuf::from(arg));
472 return;
473 }
474
475 if arg.starts_with("-I") {
476 includes.push(
477 arg.strip_prefix("-I")
478 .expect("Failed to strip -I")
479 .to_string(),
480 );
481 return;
482 }
483
484 if arg == "--is_tonic" {
485 is_tonic = true;
486 return;
487 }
488
489 if !arg.contains('=') {
490 extra_args.push(arg);
491 return;
492 }
493
494 let parts = arg.split_once('=').expect("Failed to split argument on =");
495 match parts {
496 ("--protoc", value) => {
497 protoc = Some(PathBuf::from(value));
498 }
499 ("--prost_out", value) => {
500 out_dir = Some(PathBuf::from(value));
501 }
502 ("--package_info_output", value) => {
503 let (key, value) = value
504 .split_once('=')
505 .map(|(a, b)| (a.to_string(), PathBuf::from(b)))
506 .expect("Failed to parse package info output");
507 crate_name = Some(key);
508 package_info_file = Some(value);
509 }
510 ("--deps_info", value) => {
511 for line in fs::read_to_string(value)
512 .expect("Failed to read file")
513 .lines()
514 {
515 let path = PathBuf::from(line.trim());
516 for flag in fs::read_to_string(path)
517 .expect("Failed to read file")
518 .lines()
519 {
520 tonic_or_prost_opts.push(format!("extern_path={}", flag.trim()));
521 }
522 }
523 }
524 ("--descriptor_set", value) => {
525 descriptor_set = Some(PathBuf::from(value));
526 }
527 ("--out_librs", value) => {
528 out_librs = Some(PathBuf::from(value));
529 }
530 ("--rustfmt", value) => {
531 rustfmt = Some(PathBuf::from(value));
532 }
533 ("--proto_path", value) => {
534 proto_paths.push(value.to_string());
535 }
536 ("--label", value) => {
537 label = Some(value.to_string());
538 }
539 (arg, value) => {
540 extra_args.push(format!("{}={}", arg, value));
541 }
542 }
543 };
544
545 // Iterate over the given command line arguments parsing out arguments
546 // for the process runner and arguments for protoc and potentially spawn
547 // additional arguments needed by prost.
548 for arg in env::args().skip(1) {
549 if let Some(path) = arg.strip_prefix('@') {
550 // handle argfile
551 let file = std::fs::File::open(path)
552 .map_err(|_| format!("could not open argfile: {}", arg))?;
553 for line in std::io::BufReader::new(file).lines() {
554 handle_arg(line.map_err(|_| format!("could not read argfile: {}", arg))?);
555 }
556 } else {
557 handle_arg(arg);
558 }
559 }
560
561 for tonic_or_prost_opt in tonic_or_prost_opts {
562 extra_args.push(format!("--prost_opt={}", tonic_or_prost_opt));
563 if is_tonic {
564 extra_args.push(format!("--tonic_opt={}", tonic_or_prost_opt));
565 }
566 }
567
568 if protoc.is_none() {
569 return Err(
570 "No `--protoc` value was found. Unable to parse path to proto compiler."
571 .to_string(),
572 );
573 }
574 if out_dir.is_none() {
575 return Err(
576 "No `--prost_out` value was found. Unable to parse output directory.".to_string(),
577 );
578 }
579 if crate_name.is_none() {
580 return Err(
581 "No `--package_info_output` value was found. Unable to parse target crate name."
582 .to_string(),
583 );
584 }
585 if package_info_file.is_none() {
586 return Err("No `--package_info_output` value was found. Unable to parse package info output file.".to_string());
587 }
588 if out_librs.is_none() {
589 return Err("No `--out_librs` value was found. Unable to parse the output location for all combined prost outputs.".to_string());
590 }
591 if descriptor_set.is_none() {
592 return Err(
593 "No `--descriptor_set` value was found. Unable to parse descriptor set path."
594 .to_string(),
595 );
596 }
597 if label.is_none() {
598 return Err(
599 "No `--label` value was found. Unable to parse the label of the target crate."
600 .to_string(),
601 );
602 }
603
604 Ok(Args {
605 protoc: protoc.unwrap(),
606 out_dir: out_dir.unwrap(),
607 crate_name: crate_name.unwrap(),
608 package_info_file: package_info_file.unwrap(),
609 proto_files,
610 includes,
611 descriptor_set: descriptor_set.unwrap(),
612 out_librs: out_librs.unwrap(),
613 rustfmt,
614 proto_paths,
615 is_tonic,
616 label: label.unwrap(),
617 extra_args,
618 })
619 }
620 }
621
/// Get the output directory with the label suffixed.
///
/// The label is flattened into a single path component: `@` is dropped, and
/// `//`, `/` and `:` become `_`. For example, `@repo//pkg:target` under
/// `out` yields `out/prost-build-repo_pkg_target`.
///
/// The path is built with [`Path::join`] rather than string concatenation.
/// An empty `out_dir` therefore yields a relative `prost-build-…` directory
/// instead of one rooted at `/`.
fn get_output_dir(out_dir: &Path, label: &str) -> PathBuf {
    let label_as_path = label
        .replace('@', "")
        .replace("//", "_")
        .replace(['/', ':'], "_");
    out_dir.join(format!("prost-build-{}", label_as_path))
}
634
635 /// Get the output directory with the label suffixed, and create it if it doesn't exist.
636 ///
637 /// This will remove the directory first if it already exists.
get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf638 fn get_and_create_output_dir(out_dir: &Path, label: &str) -> PathBuf {
639 let out_dir = get_output_dir(out_dir, label);
640 if out_dir.exists() {
641 fs::remove_dir_all(&out_dir).expect("Failed to remove old output directory");
642 }
643 fs::create_dir_all(&out_dir).expect("Failed to create output directory");
644 out_dir
645 }
646
647 /// Parse the descriptor set file into a `FileDescriptorSet`.
parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet648 fn parse_descriptor_set_file(descriptor_set_path: &PathBuf) -> FileDescriptorSet {
649 let descriptor_set_bytes =
650 fs::read(descriptor_set_path).expect("Failed to read descriptor set");
651 let descriptor_set = FileDescriptorSet::decode(descriptor_set_bytes.as_slice())
652 .expect("Failed to decode descriptor set");
653
654 descriptor_set
655 }
656
657 /// Get the package name from the descriptor set.
get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String>658 fn get_package_name(descriptor_set: &FileDescriptorSet) -> Option<String> {
659 let mut package_name = None;
660
661 for file in &descriptor_set.file {
662 if let Some(package) = &file.package {
663 package_name = Some(package.clone());
664 break;
665 }
666 }
667
668 package_name
669 }
670
671 /// Whether the proto file should expect to generate a .rs file.
672 ///
673 /// If the proto file contains any messages, enums, or services, then it should generate a rust file.
674 /// If the proto file only contains extensions, then it will not generate any rust files.
expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool675 fn expect_fs_file_to_be_generated(descriptor_set: &FileDescriptorSet) -> bool {
676 let mut expect_rs = false;
677
678 for file in descriptor_set.file.iter() {
679 let has_messages = !file.message_type.is_empty();
680 let has_enums = !file.enum_type.is_empty();
681 let has_services = !file.service.is_empty();
682 let has_extensions = !file.extension.is_empty();
683
684 let has_definition = has_messages || has_enums || has_services;
685
686 if has_definition {
687 return true;
688 } else if !has_definition && !has_extensions {
689 expect_rs = true;
690 }
691 }
692
693 expect_rs
694 }
695
696 /// Whether the proto file should expect to generate service definitions.
has_services(descriptor_set: &FileDescriptorSet) -> bool697 fn has_services(descriptor_set: &FileDescriptorSet) -> bool {
698 descriptor_set
699 .file
700 .iter()
701 .any(|file| !file.service.is_empty())
702 }
703
main()704 fn main() {
705 // Always enable backtraces for the protoc wrapper.
706 env::set_var("RUST_BACKTRACE", "1");
707
708 let Args {
709 protoc,
710 out_dir,
711 crate_name,
712 label,
713 package_info_file,
714 proto_files,
715 includes,
716 descriptor_set,
717 out_librs,
718 rustfmt,
719 proto_paths,
720 is_tonic,
721 extra_args,
722 } = Args::parse().expect("Failed to parse args");
723
724 let out_dir = get_and_create_output_dir(&out_dir, &label);
725
726 let descriptor_set = parse_descriptor_set_file(&descriptor_set);
727 let package_name = get_package_name(&descriptor_set).unwrap_or_default();
728 let expect_rs = expect_fs_file_to_be_generated(&descriptor_set);
729 let has_services = has_services(&descriptor_set);
730
731 if has_services && !is_tonic {
732 eprintln!("Warning: Service definitions will not be generated because the prost toolchain did not define a tonic plugin.");
733 }
734
735 let mut cmd = process::Command::new(protoc);
736 cmd.arg(format!("--prost_out={}", out_dir.display()));
737 if is_tonic {
738 cmd.arg(format!("--tonic_out={}", out_dir.display()));
739 }
740 cmd.args(extra_args);
741 cmd.args(
742 proto_paths
743 .iter()
744 .map(|proto_path| format!("--proto_path={}", proto_path)),
745 );
746 cmd.args(includes.iter().map(|include| format!("-I{}", include)));
747 cmd.args(&proto_files);
748
749 let status = cmd.status().expect("Failed to spawn protoc process");
750 if !status.success() {
751 panic!(
752 "protoc failed with status: {}",
753 status.code().expect("failed to get exit code")
754 );
755 }
756
757 // Not all proto files will consistently produce `.rs` or `.tonic.rs` files. This is
758 // caused by the proto file being transpiled not having an RPC service or other protos
759 // defined (a natural and expected situation). To guarantee consistent outputs, all
760 // `.rs` files are either renamed to `.tonic.rs` if there is no `.tonic.rs` or prepended
761 // to the existing `.tonic.rs`.
762 if is_tonic {
763 let tonic_files: BTreeSet<PathBuf> = find_generated_rust_files(&out_dir);
764
765 for tonic_file in tonic_files.iter() {
766 let tonic_path_str = tonic_file.to_str().expect("Failed to convert to str");
767 let filename = tonic_file
768 .file_name()
769 .expect("Failed to get file name")
770 .to_str()
771 .expect("Failed to convert to str");
772
773 let is_tonic_file = filename.ends_with(".tonic.rs");
774
775 if is_tonic_file {
776 let rs_file_str = format!(
777 "{}.rs",
778 tonic_path_str
779 .strip_suffix(".tonic.rs")
780 .expect("Failed to strip suffix.")
781 );
782 let rs_file = PathBuf::from(&rs_file_str);
783
784 if rs_file.exists() {
785 let rs_content = fs::read_to_string(&rs_file).expect("Failed to read file.");
786 let tonic_content =
787 fs::read_to_string(tonic_file).expect("Failed to read file.");
788 fs::write(tonic_file, format!("{}\n{}", rs_content, tonic_content))
789 .expect("Failed to write file.");
790 fs::remove_file(&rs_file).unwrap_or_else(|err| {
791 panic!("Failed to remove file: {err:?}: {rs_file:?}")
792 });
793 }
794 } else {
795 let real_tonic_file = PathBuf::from(format!(
796 "{}.tonic.rs",
797 tonic_path_str
798 .strip_suffix(".rs")
799 .expect("Failed to strip suffix.")
800 ));
801 if real_tonic_file.exists() {
802 continue;
803 }
804 fs::rename(tonic_file, &real_tonic_file).unwrap_or_else(|err| {
805 panic!("Failed to rename file: {err:?}: {tonic_file:?} -> {real_tonic_file:?}");
806 });
807 }
808 }
809 }
810
811 // Locate all prost-generated outputs.
812 let mut rust_files = find_generated_rust_files(&out_dir);
813 if rust_files.is_empty() {
814 if expect_rs {
815 panic!("No .rs files were generated by prost.");
816 } else {
817 let file_stem = if package_name.is_empty() {
818 "_"
819 } else {
820 &package_name
821 };
822 let file_stem = format!("{}{}", file_stem, if is_tonic { ".tonic" } else { "" });
823 let empty_rs_file = out_dir.join(format!("{}.rs", file_stem));
824 fs::write(&empty_rs_file, "").expect("Failed to write file.");
825 rust_files.insert(empty_rs_file);
826 }
827 }
828
829 let extern_paths = get_extern_paths(&descriptor_set, &crate_name)
830 .expect("Failed to compute proto package info");
831
832 // Write outputs
833 fs::write(&out_librs, generate_lib_rs(&rust_files, is_tonic)).expect("Failed to write file.");
834 fs::write(
835 package_info_file,
836 extern_paths
837 .into_iter()
838 .map(|(proto_path, rust_path)| format!(".{}=::{}", proto_path, rust_path))
839 .collect::<Vec<_>>()
840 .join("\n"),
841 )
842 .expect("Failed to write file.");
843
844 // Finally run rustfmt on the output lib.rs file
845 if let Some(rustfmt) = rustfmt {
846 let fmt_status = process::Command::new(rustfmt)
847 .arg("--edition")
848 .arg("2021")
849 .arg("--quiet")
850 .arg(&out_librs)
851 .status()
852 .expect("Failed to spawn rustfmt process");
853 if !fmt_status.success() {
854 panic!(
855 "rustfmt failed with exit code: {}",
856 fmt_status.code().expect("Failed to get exit code")
857 );
858 }
859 }
860 }
861
/// Rust strict and reserved keywords, grouped by kind.
const RUST_KEYWORDS: [&str; 51] = [
    // Strict keywords.
    "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn", "for",
    "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref", "return",
    "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use",
    "where", "while",
    // Strict keywords since the 2018 edition.
    "async", "await", "dyn",
    // Reserved for future use.
    "abstract", "become", "box", "do", "final", "macro", "override", "priv", "typeof",
    "unsized", "virtual", "yield",
    // Reserved since the 2018 edition.
    "try",
];
870
871 /// Returns true if the given string is a Rust keyword.
is_keyword(s: &str) -> bool872 fn is_keyword(s: &str) -> bool {
873 RUST_KEYWORDS.contains(&s)
874 }
875
876 /// Escapes a Rust keyword by prefixing it with `r#`.
escape_keyword(s: String) -> String877 fn escape_keyword(s: String) -> String {
878 if is_keyword(&s) {
879 return format!("r#{s}");
880 }
881 s
882 }
883
884 #[cfg(test)]
885 mod test {
886
887 use super::*;
888
889 use prost_types::{FieldDescriptorProto, ServiceDescriptorProto};
890
891 #[test]
oneof_type_to_extern_paths_test()892 fn oneof_type_to_extern_paths_test() {
893 let oneof_descriptor = OneofDescriptorProto {
894 name: Some("Foo".to_string()),
895 ..OneofDescriptorProto::default()
896 };
897
898 {
899 let mut extern_paths = BTreeMap::new();
900 oneof_type_to_extern_paths(
901 &mut extern_paths,
902 &ProtoPath::from("bar"),
903 &RustModulePath::from("bar"),
904 &oneof_descriptor,
905 );
906
907 assert_eq!(extern_paths.len(), 1);
908 assert_eq!(
909 extern_paths.get(&ProtoPath::from("bar.Foo")),
910 Some(&RustModulePath::from("bar::Foo"))
911 );
912 }
913
914 {
915 let mut extern_paths = BTreeMap::new();
916 oneof_type_to_extern_paths(
917 &mut extern_paths,
918 &ProtoPath::from("bar.baz"),
919 &RustModulePath::from("bar::baz"),
920 &oneof_descriptor,
921 );
922
923 assert_eq!(extern_paths.len(), 1);
924 assert_eq!(
925 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
926 Some(&RustModulePath::from("bar::baz::Foo"))
927 );
928 }
929 }
930
931 #[test]
enum_type_to_extern_paths_test()932 fn enum_type_to_extern_paths_test() {
933 let enum_descriptor = EnumDescriptorProto {
934 name: Some("Foo".to_string()),
935 ..EnumDescriptorProto::default()
936 };
937
938 {
939 let mut extern_paths = BTreeMap::new();
940 enum_type_to_extern_paths(
941 &mut extern_paths,
942 &ProtoPath::from("bar"),
943 &RustModulePath::from("bar"),
944 &enum_descriptor,
945 );
946
947 assert_eq!(extern_paths.len(), 1);
948 assert_eq!(
949 extern_paths.get(&ProtoPath::from("bar.Foo")),
950 Some(&RustModulePath::from("bar::Foo"))
951 );
952 }
953
954 {
955 let mut extern_paths = BTreeMap::new();
956 enum_type_to_extern_paths(
957 &mut extern_paths,
958 &ProtoPath::from("bar.baz"),
959 &RustModulePath::from("bar::baz"),
960 &enum_descriptor,
961 );
962
963 assert_eq!(extern_paths.len(), 1);
964 assert_eq!(
965 extern_paths.get(&ProtoPath::from("bar.baz.Foo")),
966 Some(&RustModulePath::from("bar::baz::Foo"))
967 );
968 }
969 }
970
#[test]
fn message_type_to_extern_paths_test() {
    // Message tree under test (proto syntax):
    //
    //   message Foo {
    //     message Bar {}
    //     message Nested {
    //       message Baz {
    //         enum Chuck {}
    //       }
    //     }
    //     enum Qux {}
    //   }
    let message_descriptor = DescriptorProto {
        name: Some("Foo".to_string()),
        nested_type: vec![
            DescriptorProto {
                name: Some("Bar".to_string()),
                ..DescriptorProto::default()
            },
            DescriptorProto {
                name: Some("Nested".to_string()),
                nested_type: vec![DescriptorProto {
                    name: Some("Baz".to_string()),
                    enum_type: vec![EnumDescriptorProto {
                        name: Some("Chuck".to_string()),
                        ..EnumDescriptorProto::default()
                    }],
                    ..DescriptorProto::default()
                }],
                ..DescriptorProto::default()
            },
        ],
        enum_type: vec![EnumDescriptorProto {
            name: Some("Qux".to_string()),
            ..EnumDescriptorProto::default()
        }],
        ..DescriptorProto::default()
    };

    // Every named type in the tree (4 messages + 2 enums), relative to the
    // package prefix, paired with the Rust path it is expected to map to.
    // Types nested inside a message are expected under the lower-cased
    // message name, while the type name itself keeps its casing. Listing
    // the enums too means the `len()` check below is fully accounted for.
    let expected_suffixes = [
        ("Foo", "Foo"),
        ("foo.Bar", "foo::Bar"),
        ("foo.Nested", "foo::Nested"),
        ("foo.nested.Baz", "foo::nested::Baz"),
        ("foo.Qux", "foo::Qux"),
        ("foo.nested.baz.Chuck", "foo::nested::baz::Chuck"),
    ];

    // Exercise both a single-segment and a multi-segment package so the
    // prefix is checked at every nesting level.
    for (proto_prefix, rust_prefix) in [("bar", "bar"), ("bar.bob", "bar::bob")] {
        let mut extern_paths = BTreeMap::new();
        message_type_to_extern_paths(
            &mut extern_paths,
            &ProtoPath::from(proto_prefix),
            &RustModulePath::from(rust_prefix),
            &message_descriptor,
        );

        assert_eq!(extern_paths.len(), expected_suffixes.len());
        for (proto_suffix, rust_suffix) in expected_suffixes {
            let proto_path = format!("{proto_prefix}.{proto_suffix}");
            let rust_path = format!("{rust_prefix}::{rust_suffix}");
            assert_eq!(
                extern_paths.get(&ProtoPath::from(proto_path.as_str())),
                Some(&RustModulePath::from(rust_path.as_str())),
                "unexpected extern path mapping for `{proto_path}`"
            );
        }
    }
}
1054
#[test]
fn proto_path_test() {
    // Each row: (input, expected `Display` output, segment to join, expected joined path).
    let cases = [
        ("", "", "foo", "foo"),
        ("foo", "foo", "", "foo"),
        ("foo", "foo", "bar", "foo.bar"),
        ("foo.bar", "foo.bar", "baz", "foo.bar.baz"),
        ("Foo.baR", "Foo.baR", "baz", "Foo.baR.baz"),
    ];

    for (input, displayed, segment, joined) in cases {
        let path = ProtoPath::from(input);
        assert_eq!(path.to_string(), displayed);
        assert_eq!(path.join(segment), ProtoPath::from(joined));
    }
}
1083
#[test]
fn rust_module_path_test() {
    // Each row: (input, expected `Display` output, segment to join, expected joined path).
    let cases = [
        ("", "", "foo", "foo"),
        ("foo", "foo", "", "foo"),
        ("foo", "foo", "bar", "foo::bar"),
        ("foo::bar", "foo::bar", "baz", "foo::bar::baz"),
    ];

    for (input, displayed, segment, joined) in cases {
        let module_path = RustModulePath::from(input);
        assert_eq!(module_path.to_string(), displayed);
        assert_eq!(module_path.join(segment), RustModulePath::from(joined));
    }
}
1113
#[test]
fn expect_fs_file_to_be_generated_test() {
    // Builds a descriptor set holding exactly one file named `foo.proto`,
    // taking every other field from `contents`.
    let single_file = |contents: FileDescriptorProto| FileDescriptorSet {
        file: vec![FileDescriptorProto {
            name: Some("foo.proto".to_string()),
            ..contents
        }],
    };

    // Empty descriptor set should create a file.
    let empty = single_file(FileDescriptorProto::default());
    assert!(expect_fs_file_to_be_generated(&empty));

    // Descriptor set with only message should create a file.
    let only_message = single_file(FileDescriptorProto {
        message_type: vec![DescriptorProto {
            name: Some("Foo".to_string()),
            ..DescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&only_message));

    // Descriptor set with only enum should create a file.
    let only_enum = single_file(FileDescriptorProto {
        enum_type: vec![EnumDescriptorProto {
            name: Some("Foo".to_string()),
            ..EnumDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&only_enum));

    // Descriptor set with only service should create a file.
    let only_service = single_file(FileDescriptorProto {
        service: vec![ServiceDescriptorProto {
            name: Some("Foo".to_string()),
            ..ServiceDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(expect_fs_file_to_be_generated(&only_service));

    // Descriptor set with only extensions should not create a file.
    let only_extension = single_file(FileDescriptorProto {
        extension: vec![FieldDescriptorProto {
            name: Some("Foo".to_string()),
            ..FieldDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(!expect_fs_file_to_be_generated(&only_extension));
}
1183
#[test]
fn has_services_test() {
    // Builds a descriptor set holding exactly one file named `foo.proto`,
    // taking every other field from `contents`.
    let single_file = |contents: FileDescriptorProto| FileDescriptorSet {
        file: vec![FileDescriptorProto {
            name: Some("foo.proto".to_string()),
            ..contents
        }],
    };

    // Empty file should not have services.
    let empty = single_file(FileDescriptorProto::default());
    assert!(!has_services(&empty));

    // File with only message should not have services.
    let only_message = single_file(FileDescriptorProto {
        message_type: vec![DescriptorProto {
            name: Some("Foo".to_string()),
            ..DescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(!has_services(&only_message));

    // File with services should have services.
    let with_service = single_file(FileDescriptorProto {
        service: vec![ServiceDescriptorProto {
            name: Some("Foo".to_string()),
            ..ServiceDescriptorProto::default()
        }],
        ..FileDescriptorProto::default()
    });
    assert!(has_services(&with_service));
}
1225
#[test]
fn get_package_name_test() {
    // A single file declaring `package foo;`.
    let file = FileDescriptorProto {
        name: Some(String::from("foo.proto")),
        package: Some(String::from("foo")),
        ..FileDescriptorProto::default()
    };
    let descriptor_set = FileDescriptorSet { file: vec![file] };

    let package = get_package_name(&descriptor_set);
    assert_eq!(package, Some(String::from("foo")));
}
1238
#[test]
fn is_keyword_test() {
    // Ordinary identifiers that must never be reported as keywords.
    let non_keywords = [
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
        "plugh", "xyzzy", "thud",
    ];
    // The assertions run inside loops, so each one names the word under test;
    // otherwise a failure would not say which entry misbehaved.
    for non_keyword in &non_keywords {
        assert!(
            !is_keyword(non_keyword),
            "`{non_keyword}` must not be treated as a Rust keyword"
        );
    }

    // Every entry of the keyword table must be recognised.
    for keyword in &RUST_KEYWORDS {
        assert!(
            is_keyword(keyword),
            "`{keyword}` must be recognised as a Rust keyword"
        );
    }
}
1253
#[test]
fn escape_keyword_test() {
    // Ordinary identifiers are returned unchanged.
    let plain_words = [
        "foo", "bar", "baz", "qux", "quux", "corge", "grault", "garply", "waldo", "fred",
        "plugh", "xyzzy", "thud",
    ];
    plain_words.iter().for_each(|word| {
        assert_eq!(escape_keyword(word.to_string()), *word);
    });

    // Every reserved word is turned into a raw identifier.
    RUST_KEYWORDS.iter().for_each(|keyword| {
        let raw_ident = format!("r#{keyword}");
        assert_eq!(escape_keyword(keyword.to_string()), raw_ident);
    });
}
1274 }
1275