1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2 
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6 
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11 
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13 
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17     P: AsRef<Path>,
18 {
19     let mut prost_config = Config::new();
20     prost_config.service_generator(Box::new(Generator));
21     prost_config.out_dir(out_dir);
22 
23     // Create a file descriptor set for the protocol files.
24     let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25     std::fs::create_dir_all(tmp.path())?;
26     let descriptor_set = tmp.path().join("prost-descriptor-set");
27 
28     let mut cmd = Command::new(prost_build::protoc_from_env());
29     cmd.arg("--include_imports")
30         .arg("--include_source_info")
31         .arg("-o")
32         .arg(&descriptor_set);
33 
34     for include in includes {
35         cmd.arg("-I").arg(include.as_ref());
36     }
37 
38     // Set the protoc include after the user includes in case the user wants to
39     // override one of the built-in .protos.
40     if let Some(inc) = prost_build::protoc_include_from_env() {
41         cmd.arg("-I").arg(inc);
42     }
43 
44     for proto in protos {
45         cmd.arg(proto.as_ref());
46     }
47 
48     let output = cmd.output()?;
49     if !output.status.success() {
50         return Err(Error::new(
51             ErrorKind::Other,
52             format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
53         ));
54     }
55 
56     let mut buf = Vec::new();
57     fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
58     let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
59 
60     // Get the package names from the descriptor set.
61     let mut packages: Vec<_> = descriptor_set
62         .file
63         .iter()
64         .filter_map(|f| f.package.clone())
65         .collect();
66     packages.sort();
67     packages.dedup();
68 
69     // FIXME(https://github.com/danburkert/prost/pull/155)
70     // Unfortunately we have to forget the above work and use `compile_protos` to
71     // actually generate the Rust code.
72     prost_config.compile_protos(protos, includes)?;
73 
74     Ok(packages)
75 }
76 
77 struct Generator;
78 
79 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)80     fn generate(&mut self, service: Service, buf: &mut String) {
81         generate_methods(&service, buf);
82         generate_client(&service, buf);
83         generate_server(&service, buf);
84     }
85 }
86 
generate_methods(service: &Service, buf: &mut String)87 fn generate_methods(service: &Service, buf: &mut String) {
88     let service_path = if service.package.is_empty() {
89         format!("/{}", service.proto_name)
90     } else {
91         format!("/{}.{}", service.package, service.proto_name)
92     };
93 
94     for method in &service.methods {
95         generate_method(&service.name, &service_path, method, buf);
96     }
97 }
98 
const_method_name(service_name: &str, method: &Method) -> String99 fn const_method_name(service_name: &str, method: &Method) -> String {
100     format!(
101         "METHOD_{}_{}",
102         to_snake_case(service_name).to_uppercase(),
103         method.name.to_uppercase()
104     )
105 }
106 
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)107 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
108     let name = const_method_name(service_name, method);
109     let ty = format!(
110         "{}<{}, {}>",
111         fq_grpc("Method"),
112         method.input_type,
113         method.output_type
114     );
115 
116     buf.push_str("const ");
117     buf.push_str(&name);
118     buf.push_str(": ");
119     buf.push_str(&ty);
120     buf.push_str(" = ");
121     generate_method_body(service_path, method, buf);
122 }
123 
generate_method_body(service_path: &str, method: &Method, buf: &mut String)124 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
125     let ty = fq_grpc(&MethodType::from_method(method).to_string());
126     let pr_mar = format!(
127         "{} {{ ser: {}, de: {} }}",
128         fq_grpc("Marshaller"),
129         fq_grpc("pr_ser"),
130         fq_grpc("pr_de")
131     );
132 
133     buf.push_str(&fq_grpc("Method"));
134     buf.push('{');
135     generate_field_init("ty", &ty, buf);
136     generate_field_init(
137         "name",
138         &format!("\"{}/{}\"", service_path, method.proto_name),
139         buf,
140     );
141     generate_field_init("req_mar", &pr_mar, buf);
142     generate_field_init("resp_mar", &pr_mar, buf);
143     buf.push_str("};\n");
144 }
145 
146 // TODO share this code with protobuf codegen
147 impl MethodType {
from_method(method: &Method) -> MethodType148     fn from_method(method: &Method) -> MethodType {
149         match (method.client_streaming, method.server_streaming) {
150             (false, false) => MethodType::Unary,
151             (true, false) => MethodType::ClientStreaming,
152             (false, true) => MethodType::ServerStreaming,
153             (true, true) => MethodType::Duplex,
154         }
155     }
156 }
157 
/// Appends a single struct-literal field initializer, `<name>: <value>, `
/// (including the trailing comma and space), to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
164 
generate_client(service: &Service, buf: &mut String)165 fn generate_client(service: &Service, buf: &mut String) {
166     let client_name = format!("{}Client", service.name);
167     buf.push_str("#[derive(Clone)]\n");
168     buf.push_str("pub struct ");
169     buf.push_str(&client_name);
170     buf.push_str(" { pub client: ::grpcio::Client }\n");
171 
172     buf.push_str("impl ");
173     buf.push_str(&client_name);
174     buf.push_str(" {\n");
175     generate_ctor(&client_name, buf);
176     generate_client_methods(service, buf);
177     generate_spawn(buf);
178     buf.push_str("}\n")
179 }
180 
/// Emits the client's `new(channel)` constructor, which builds the wrapped
/// `::grpcio::Client` from the given channel.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = [
        "pub fn new(channel: ::grpcio::Channel) -> Self { ",
        client_name,
        " { client: ::grpcio::Client::new(channel) }}\n",
    ]
    .concat();
    buf.push_str(&ctor);
}
187 
generate_client_methods(service: &Service, buf: &mut String)188 fn generate_client_methods(service: &Service, buf: &mut String) {
189     for method in &service.methods {
190         generate_client_method(&service.name, method, buf);
191     }
192 }
193 
generate_client_method(service_name: &str, method: &Method, buf: &mut String)194 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
195     let name = &format!(
196         "METHOD_{}_{}",
197         to_snake_case(service_name).to_uppercase(),
198         method.name.to_uppercase()
199     );
200     match MethodType::from_method(method) {
201         MethodType::Unary => {
202             ClientMethod::new(
203                 &method.name,
204                 true,
205                 Some(&method.input_type),
206                 false,
207                 vec![&method.output_type],
208                 "unary_call",
209                 name,
210             )
211             .generate(buf);
212             ClientMethod::new(
213                 &method.name,
214                 false,
215                 Some(&method.input_type),
216                 false,
217                 vec![&method.output_type],
218                 "unary_call",
219                 name,
220             )
221             .generate(buf);
222             ClientMethod::new(
223                 &method.name,
224                 true,
225                 Some(&method.input_type),
226                 true,
227                 vec![&format!(
228                     "{}<{}>",
229                     fq_grpc("ClientUnaryReceiver"),
230                     method.output_type
231                 )],
232                 "unary_call",
233                 name,
234             )
235             .generate(buf);
236             ClientMethod::new(
237                 &method.name,
238                 false,
239                 Some(&method.input_type),
240                 true,
241                 vec![&format!(
242                     "{}<{}>",
243                     fq_grpc("ClientUnaryReceiver"),
244                     method.output_type
245                 )],
246                 "unary_call",
247                 name,
248             )
249             .generate(buf);
250         }
251         MethodType::ClientStreaming => {
252             ClientMethod::new(
253                 &method.name,
254                 true,
255                 None,
256                 false,
257                 vec![
258                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
259                     &format!(
260                         "{}<{}>",
261                         fq_grpc("ClientCStreamReceiver"),
262                         method.output_type
263                     ),
264                 ],
265                 "client_streaming",
266                 name,
267             )
268             .generate(buf);
269             ClientMethod::new(
270                 &method.name,
271                 false,
272                 None,
273                 false,
274                 vec![
275                     &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
276                     &format!(
277                         "{}<{}>",
278                         fq_grpc("ClientCStreamReceiver"),
279                         method.output_type
280                     ),
281                 ],
282                 "client_streaming",
283                 name,
284             )
285             .generate(buf);
286         }
287         MethodType::ServerStreaming => {
288             ClientMethod::new(
289                 &method.name,
290                 true,
291                 Some(&method.input_type),
292                 false,
293                 vec![&format!(
294                     "{}<{}>",
295                     fq_grpc("ClientSStreamReceiver"),
296                     method.output_type
297                 )],
298                 "server_streaming",
299                 name,
300             )
301             .generate(buf);
302             ClientMethod::new(
303                 &method.name,
304                 false,
305                 Some(&method.input_type),
306                 false,
307                 vec![&format!(
308                     "{}<{}>",
309                     fq_grpc("ClientSStreamReceiver"),
310                     method.output_type
311                 )],
312                 "server_streaming",
313                 name,
314             )
315             .generate(buf);
316         }
317         MethodType::Duplex => {
318             ClientMethod::new(
319                 &method.name,
320                 true,
321                 None,
322                 false,
323                 vec![
324                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
325                     &format!(
326                         "{}<{}>",
327                         fq_grpc("ClientDuplexReceiver"),
328                         method.output_type
329                     ),
330                 ],
331                 "duplex_streaming",
332                 name,
333             )
334             .generate(buf);
335             ClientMethod::new(
336                 &method.name,
337                 false,
338                 None,
339                 false,
340                 vec![
341                     &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
342                     &format!(
343                         "{}<{}>",
344                         fq_grpc("ClientDuplexReceiver"),
345                         method.output_type
346                     ),
347                 ],
348                 "duplex_streaming",
349                 name,
350             )
351             .generate(buf);
352         }
353     }
354 }
355 
/// Description of one generated client function; `generate` renders it.
///
/// The constructor comes from `#[derive(new)]`; callers pass its arguments
/// positionally in the order the fields are declared here.
#[derive(new)]
struct ClientMethod<'a> {
    /// Base name of the generated function, before any `_async`/`_opt` suffix.
    method_name: &'a str,
    /// When true the function is the `_opt` variant: it takes a `CallOption`
    /// and calls the inner client. When false it delegates to the `_opt` variant.
    opt: bool,
    /// Type of the `req` argument (taken by reference), if the function has one.
    request: Option<&'a str>,
    /// When true, `_async` is appended to the generated name and to the inner
    /// client method name.
    r#async: bool,
    /// Types placed inside the returned `Result`; anything other than exactly
    /// one type is wrapped in a tuple.
    result_types: Vec<&'a str>,
    /// Name of the method invoked on `self.client`.
    inner_method_name: &'a str,
    /// Identifier passed by reference as the first argument of the inner call.
    data_name: &'a str,
}
366 
367 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)368     fn generate(&self, buf: &mut String) {
369         buf.push_str("pub fn ");
370 
371         buf.push_str(self.method_name);
372         if self.r#async {
373             buf.push_str("_async");
374         }
375         if self.opt {
376             buf.push_str("_opt");
377         }
378 
379         buf.push_str("(&self");
380         if let Some(req) = self.request {
381             buf.push_str(", req: &");
382             buf.push_str(req);
383         }
384         if self.opt {
385             buf.push_str(", opt: ");
386             buf.push_str(&fq_grpc("CallOption"));
387         }
388         buf.push_str(") -> ");
389 
390         buf.push_str(&fq_grpc("Result"));
391         buf.push('<');
392         if self.result_types.len() != 1 {
393             buf.push('(');
394         }
395         for rt in &self.result_types {
396             buf.push_str(rt);
397             buf.push(',');
398         }
399         if self.result_types.len() != 1 {
400             buf.push(')');
401         }
402         buf.push_str("> { ");
403         if self.opt {
404             self.generate_inner_body(buf);
405         } else {
406             self.generate_opt_body(buf);
407         }
408         buf.push_str(" }\n");
409     }
410 
411     // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)412     fn generate_opt_body(&self, buf: &mut String) {
413         buf.push_str("self.");
414         buf.push_str(self.method_name);
415         if self.r#async {
416             buf.push_str("_async");
417         }
418         buf.push_str("_opt(");
419         if self.request.is_some() {
420             buf.push_str("req, ");
421         }
422         buf.push_str(&fq_grpc("CallOption::default()"));
423         buf.push(')');
424     }
425 
426     // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)427     fn generate_inner_body(&self, buf: &mut String) {
428         buf.push_str("self.client.");
429         buf.push_str(self.inner_method_name);
430         if self.r#async {
431             buf.push_str("_async");
432         }
433         buf.push_str("(&");
434         buf.push_str(self.data_name);
435         if self.request.is_some() {
436             buf.push_str(", req");
437         }
438         buf.push_str(", opt)");
439     }
440 }
441 
/// Emits the client's `spawn` method, which forwards a `Send + 'static`
/// future with unit output to the inner client's `spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN_FN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::std::future::Future<Output = ()> + Send + 'static {",
        "self.client.spawn(f)",
        "}\n",
    );
    buf.push_str(SPAWN_FN);
}
450 
generate_server(service: &Service, buf: &mut String)451 fn generate_server(service: &Service, buf: &mut String) {
452     buf.push_str("pub trait ");
453     buf.push_str(&service.name);
454     buf.push_str(" {\n");
455     generate_server_methods(service, buf);
456     buf.push_str("}\n");
457 
458     buf.push_str("pub fn create_");
459     buf.push_str(&to_snake_case(&service.name));
460     buf.push_str("<S: ");
461     buf.push_str(&service.name);
462     buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
463     buf.push_str(&fq_grpc("Service"));
464     buf.push_str(" {\n");
465     buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
466 
467     for method in &service.methods[0..service.methods.len() - 1] {
468         buf.push_str("let mut instance = s.clone();\n");
469         generate_method_bind(&service.name, method, buf);
470     }
471 
472     buf.push_str("let mut instance = s;\n");
473     generate_method_bind(
474         &service.name,
475         &service.methods[service.methods.len() - 1],
476         buf,
477     );
478 
479     buf.push_str("builder.build()\n");
480     buf.push_str("}\n");
481 }
482 
generate_server_methods(service: &Service, buf: &mut String)483 fn generate_server_methods(service: &Service, buf: &mut String) {
484     for method in &service.methods {
485         let method_type = MethodType::from_method(method);
486         let request_arg = match method_type {
487             MethodType::Unary | MethodType::ServerStreaming => {
488                 format!("req: {}", method.input_type)
489             }
490             MethodType::ClientStreaming | MethodType::Duplex => format!(
491                 "stream: {}<{}>",
492                 fq_grpc("RequestStream"),
493                 method.input_type
494             ),
495         };
496         let response_type = match method_type {
497             MethodType::Unary => "UnarySink",
498             MethodType::ClientStreaming => "ClientStreamingSink",
499             MethodType::ServerStreaming => "ServerStreamingSink",
500             MethodType::Duplex => "DuplexSink",
501         };
502         generate_server_method(method, &request_arg, response_type, buf);
503     }
504 }
505 
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )506 fn generate_server_method(
507     method: &Method,
508     request_arg: &str,
509     response_type: &str,
510     buf: &mut String,
511 ) {
512     buf.push_str("fn ");
513     buf.push_str(&method.name);
514     buf.push_str("(&mut self, ctx: ");
515     buf.push_str(&fq_grpc("RpcContext"));
516     buf.push_str(", _");
517     buf.push_str(request_arg);
518     buf.push_str(", sink: ");
519     buf.push_str(&fq_grpc(response_type));
520     buf.push('<');
521     buf.push_str(&method.output_type);
522     buf.push('>');
523     buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
524 }
525 
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)526 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
527     let add_name = match MethodType::from_method(method) {
528         MethodType::Unary => "add_unary_handler",
529         MethodType::ClientStreaming => "add_client_streaming_handler",
530         MethodType::ServerStreaming => "add_server_streaming_handler",
531         MethodType::Duplex => "add_duplex_streaming_handler",
532     };
533 
534     buf.push_str("builder = builder.");
535     buf.push_str(add_name);
536     buf.push_str("(&");
537     buf.push_str(&const_method_name(service_name, method));
538     buf.push_str(", move |ctx, req, resp| instance.");
539     buf.push_str(&method.name);
540     buf.push_str("(ctx, req, resp));\n");
541 }
542 
protoc_gen_grpc_rust_main()543 pub fn protoc_gen_grpc_rust_main() {
544     let mut args = env::args();
545     args.next();
546     let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
547     for arg in args {
548         if let Some(value) = arg.strip_prefix("--protos=") {
549             protos.extend(value.split(",").map(|s| s.to_string()));
550         } else if let Some(value) = arg.strip_prefix("--includes=") {
551             includes.extend(value.split(",").map(|s| s.to_string()));
552         } else if let Some(value) = arg.strip_prefix("--out-dir=") {
553             out_dir = value.to_string();
554         }
555     }
556     if protos.is_empty() {
557         panic!("should at least specify protos to generate");
558     }
559     compile_protos(&protos, &includes, &out_dir).unwrap();
560 }
561