1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2
3 use std::io::{Error, ErrorKind, Read};
4 use std::path::Path;
5 use std::{env, fs, io, process::Command, str};
6
7 use derive_new::new;
8 use prost::Message;
9 use prost_build::{Config, Method, Service, ServiceGenerator};
10 use prost_types::FileDescriptorSet;
11
12 use crate::util::{fq_grpc, to_snake_case, MethodType};
13
14 /// Returns the names of all packages compiled.
compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>> where P: AsRef<Path>,15 pub fn compile_protos<P>(protos: &[P], includes: &[P], out_dir: &str) -> io::Result<Vec<String>>
16 where
17 P: AsRef<Path>,
18 {
19 let mut prost_config = Config::new();
20 prost_config.service_generator(Box::new(Generator));
21 prost_config.out_dir(out_dir);
22
23 // Create a file descriptor set for the protocol files.
24 let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
25 std::fs::create_dir_all(tmp.path())?;
26 let descriptor_set = tmp.path().join("prost-descriptor-set");
27
28 let mut cmd = Command::new(prost_build::protoc_from_env());
29 cmd.arg("--include_imports")
30 .arg("--include_source_info")
31 .arg("-o")
32 .arg(&descriptor_set);
33
34 for include in includes {
35 cmd.arg("-I").arg(include.as_ref());
36 }
37
38 // Set the protoc include after the user includes in case the user wants to
39 // override one of the built-in .protos.
40 if let Some(inc) = prost_build::protoc_include_from_env() {
41 cmd.arg("-I").arg(inc);
42 }
43
44 for proto in protos {
45 cmd.arg(proto.as_ref());
46 }
47
48 let output = cmd.output()?;
49 if !output.status.success() {
50 return Err(Error::new(
51 ErrorKind::Other,
52 format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
53 ));
54 }
55
56 let mut buf = Vec::new();
57 fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
58 let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
59
60 // Get the package names from the descriptor set.
61 let mut packages: Vec<_> = descriptor_set
62 .file
63 .iter()
64 .filter_map(|f| f.package.clone())
65 .collect();
66 packages.sort();
67 packages.dedup();
68
69 // FIXME(https://github.com/danburkert/prost/pull/155)
70 // Unfortunately we have to forget the above work and use `compile_protos` to
71 // actually generate the Rust code.
72 prost_config.compile_protos(protos, includes)?;
73
74 Ok(packages)
75 }
76
77 struct Generator;
78
79 impl ServiceGenerator for Generator {
generate(&mut self, service: Service, buf: &mut String)80 fn generate(&mut self, service: Service, buf: &mut String) {
81 generate_methods(&service, buf);
82 generate_client(&service, buf);
83 generate_server(&service, buf);
84 }
85 }
86
generate_methods(service: &Service, buf: &mut String)87 fn generate_methods(service: &Service, buf: &mut String) {
88 let service_path = if service.package.is_empty() {
89 format!("/{}", service.proto_name)
90 } else {
91 format!("/{}.{}", service.package, service.proto_name)
92 };
93
94 for method in &service.methods {
95 generate_method(&service.name, &service_path, method, buf);
96 }
97 }
98
const_method_name(service_name: &str, method: &Method) -> String99 fn const_method_name(service_name: &str, method: &Method) -> String {
100 format!(
101 "METHOD_{}_{}",
102 to_snake_case(service_name).to_uppercase(),
103 method.name.to_uppercase()
104 )
105 }
106
generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String)107 fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
108 let name = const_method_name(service_name, method);
109 let ty = format!(
110 "{}<{}, {}>",
111 fq_grpc("Method"),
112 method.input_type,
113 method.output_type
114 );
115
116 buf.push_str("const ");
117 buf.push_str(&name);
118 buf.push_str(": ");
119 buf.push_str(&ty);
120 buf.push_str(" = ");
121 generate_method_body(service_path, method, buf);
122 }
123
generate_method_body(service_path: &str, method: &Method, buf: &mut String)124 fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
125 let ty = fq_grpc(&MethodType::from_method(method).to_string());
126 let pr_mar = format!(
127 "{} {{ ser: {}, de: {} }}",
128 fq_grpc("Marshaller"),
129 fq_grpc("pr_ser"),
130 fq_grpc("pr_de")
131 );
132
133 buf.push_str(&fq_grpc("Method"));
134 buf.push('{');
135 generate_field_init("ty", &ty, buf);
136 generate_field_init(
137 "name",
138 &format!("\"{}/{}\"", service_path, method.proto_name),
139 buf,
140 );
141 generate_field_init("req_mar", &pr_mar, buf);
142 generate_field_init("resp_mar", &pr_mar, buf);
143 buf.push_str("};\n");
144 }
145
146 // TODO share this code with protobuf codegen
147 impl MethodType {
from_method(method: &Method) -> MethodType148 fn from_method(method: &Method) -> MethodType {
149 match (method.client_streaming, method.server_streaming) {
150 (false, false) => MethodType::Unary,
151 (true, false) => MethodType::ClientStreaming,
152 (false, true) => MethodType::ServerStreaming,
153 (true, true) => MethodType::Duplex,
154 }
155 }
156 }
157
/// Appends a single `name: value, ` struct-literal field initializer to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in &[name, ": ", value, ", "] {
        buf.push_str(piece);
    }
}
164
generate_client(service: &Service, buf: &mut String)165 fn generate_client(service: &Service, buf: &mut String) {
166 let client_name = format!("{}Client", service.name);
167 buf.push_str("#[derive(Clone)]\n");
168 buf.push_str("pub struct ");
169 buf.push_str(&client_name);
170 buf.push_str(" { pub client: ::grpcio::Client }\n");
171
172 buf.push_str("impl ");
173 buf.push_str(&client_name);
174 buf.push_str(" {\n");
175 generate_ctor(&client_name, buf);
176 generate_client_methods(service, buf);
177 generate_spawn(buf);
178 buf.push_str("}\n")
179 }
180
/// Emits `pub fn new(channel)`, which wraps `channel` in a `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = format!(
        "pub fn new(channel: ::grpcio::Channel) -> Self {{ {} {{ client: ::grpcio::Client::new(channel) }}}}\n",
        client_name
    );
    buf.push_str(&ctor);
}
187
generate_client_methods(service: &Service, buf: &mut String)188 fn generate_client_methods(service: &Service, buf: &mut String) {
189 for method in &service.methods {
190 generate_client_method(&service.name, method, buf);
191 }
192 }
193
generate_client_method(service_name: &str, method: &Method, buf: &mut String)194 fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
195 let name = &format!(
196 "METHOD_{}_{}",
197 to_snake_case(service_name).to_uppercase(),
198 method.name.to_uppercase()
199 );
200 match MethodType::from_method(method) {
201 MethodType::Unary => {
202 ClientMethod::new(
203 &method.name,
204 true,
205 Some(&method.input_type),
206 false,
207 vec![&method.output_type],
208 "unary_call",
209 name,
210 )
211 .generate(buf);
212 ClientMethod::new(
213 &method.name,
214 false,
215 Some(&method.input_type),
216 false,
217 vec![&method.output_type],
218 "unary_call",
219 name,
220 )
221 .generate(buf);
222 ClientMethod::new(
223 &method.name,
224 true,
225 Some(&method.input_type),
226 true,
227 vec![&format!(
228 "{}<{}>",
229 fq_grpc("ClientUnaryReceiver"),
230 method.output_type
231 )],
232 "unary_call",
233 name,
234 )
235 .generate(buf);
236 ClientMethod::new(
237 &method.name,
238 false,
239 Some(&method.input_type),
240 true,
241 vec![&format!(
242 "{}<{}>",
243 fq_grpc("ClientUnaryReceiver"),
244 method.output_type
245 )],
246 "unary_call",
247 name,
248 )
249 .generate(buf);
250 }
251 MethodType::ClientStreaming => {
252 ClientMethod::new(
253 &method.name,
254 true,
255 None,
256 false,
257 vec![
258 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
259 &format!(
260 "{}<{}>",
261 fq_grpc("ClientCStreamReceiver"),
262 method.output_type
263 ),
264 ],
265 "client_streaming",
266 name,
267 )
268 .generate(buf);
269 ClientMethod::new(
270 &method.name,
271 false,
272 None,
273 false,
274 vec![
275 &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
276 &format!(
277 "{}<{}>",
278 fq_grpc("ClientCStreamReceiver"),
279 method.output_type
280 ),
281 ],
282 "client_streaming",
283 name,
284 )
285 .generate(buf);
286 }
287 MethodType::ServerStreaming => {
288 ClientMethod::new(
289 &method.name,
290 true,
291 Some(&method.input_type),
292 false,
293 vec![&format!(
294 "{}<{}>",
295 fq_grpc("ClientSStreamReceiver"),
296 method.output_type
297 )],
298 "server_streaming",
299 name,
300 )
301 .generate(buf);
302 ClientMethod::new(
303 &method.name,
304 false,
305 Some(&method.input_type),
306 false,
307 vec![&format!(
308 "{}<{}>",
309 fq_grpc("ClientSStreamReceiver"),
310 method.output_type
311 )],
312 "server_streaming",
313 name,
314 )
315 .generate(buf);
316 }
317 MethodType::Duplex => {
318 ClientMethod::new(
319 &method.name,
320 true,
321 None,
322 false,
323 vec![
324 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
325 &format!(
326 "{}<{}>",
327 fq_grpc("ClientDuplexReceiver"),
328 method.output_type
329 ),
330 ],
331 "duplex_streaming",
332 name,
333 )
334 .generate(buf);
335 ClientMethod::new(
336 &method.name,
337 false,
338 None,
339 false,
340 vec![
341 &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
342 &format!(
343 "{}<{}>",
344 fq_grpc("ClientDuplexReceiver"),
345 method.output_type
346 ),
347 ],
348 "duplex_streaming",
349 name,
350 )
351 .generate(buf);
352 }
353 }
354 }
355
/// Description of one generated client wrapper method, rendered by `ClientMethod::generate`.
/// `#[derive(new)]` builds the constructor from the fields in declaration order.
#[derive(new)]
struct ClientMethod<'a> {
    /// Base Rust method name (callers pass the RPC's `method.name`); `_async` and `_opt`
    /// suffixes are appended when rendering.
    method_name: &'a str,
    /// True: emit the `_opt` variant that takes an explicit `CallOption` and calls the inner
    /// client. False: emit the variant that delegates to the `_opt` one.
    opt: bool,
    /// Request message type, taken by reference in the generated signature. Callers pass
    /// `None` for client-streaming and duplex RPCs, which then get no `req` parameter.
    request: Option<&'a str>,
    /// When true, `_async` is appended to both the generated and the delegated method name.
    r#async: bool,
    /// Types placed inside `Result<...>`; anything other than exactly one is rendered as a tuple.
    result_types: Vec<&'a str>,
    /// Name of the `::grpcio::Client` method to call (e.g. `unary_call`).
    inner_method_name: &'a str,
    /// Name of the `METHOD_*` constant passed by reference to the inner client call.
    data_name: &'a str,
}
366
367 impl<'a> ClientMethod<'a> {
generate(&self, buf: &mut String)368 fn generate(&self, buf: &mut String) {
369 buf.push_str("pub fn ");
370
371 buf.push_str(self.method_name);
372 if self.r#async {
373 buf.push_str("_async");
374 }
375 if self.opt {
376 buf.push_str("_opt");
377 }
378
379 buf.push_str("(&self");
380 if let Some(req) = self.request {
381 buf.push_str(", req: &");
382 buf.push_str(req);
383 }
384 if self.opt {
385 buf.push_str(", opt: ");
386 buf.push_str(&fq_grpc("CallOption"));
387 }
388 buf.push_str(") -> ");
389
390 buf.push_str(&fq_grpc("Result"));
391 buf.push('<');
392 if self.result_types.len() != 1 {
393 buf.push('(');
394 }
395 for rt in &self.result_types {
396 buf.push_str(rt);
397 buf.push(',');
398 }
399 if self.result_types.len() != 1 {
400 buf.push(')');
401 }
402 buf.push_str("> { ");
403 if self.opt {
404 self.generate_inner_body(buf);
405 } else {
406 self.generate_opt_body(buf);
407 }
408 buf.push_str(" }\n");
409 }
410
411 // Method delegates to the `_opt` version of the method.
generate_opt_body(&self, buf: &mut String)412 fn generate_opt_body(&self, buf: &mut String) {
413 buf.push_str("self.");
414 buf.push_str(self.method_name);
415 if self.r#async {
416 buf.push_str("_async");
417 }
418 buf.push_str("_opt(");
419 if self.request.is_some() {
420 buf.push_str("req, ");
421 }
422 buf.push_str(&fq_grpc("CallOption::default()"));
423 buf.push(')');
424 }
425
426 // Method delegates to the inner client.
generate_inner_body(&self, buf: &mut String)427 fn generate_inner_body(&self, buf: &mut String) {
428 buf.push_str("self.client.");
429 buf.push_str(self.inner_method_name);
430 if self.r#async {
431 buf.push_str("_async");
432 }
433 buf.push_str("(&");
434 buf.push_str(self.data_name);
435 if self.request.is_some() {
436 buf.push_str(", req");
437 }
438 buf.push_str(", opt)");
439 }
440 }
441
/// Emits `spawn`, which hands a `'static + Send` unit future to the inner client's `spawn`.
fn generate_spawn(buf: &mut String) {
    const SPAWN: &str = concat!(
        "pub fn spawn<F>(&self, f: F) ",
        "where F: ::std::future::Future<Output = ()> + Send + 'static ",
        "{self.client.spawn(f)}\n"
    );
    buf.push_str(SPAWN);
}
450
generate_server(service: &Service, buf: &mut String)451 fn generate_server(service: &Service, buf: &mut String) {
452 buf.push_str("pub trait ");
453 buf.push_str(&service.name);
454 buf.push_str(" {\n");
455 generate_server_methods(service, buf);
456 buf.push_str("}\n");
457
458 buf.push_str("pub fn create_");
459 buf.push_str(&to_snake_case(&service.name));
460 buf.push_str("<S: ");
461 buf.push_str(&service.name);
462 buf.push_str(" + Send + Clone + 'static>(s: S) -> ");
463 buf.push_str(&fq_grpc("Service"));
464 buf.push_str(" {\n");
465 buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n");
466
467 for method in &service.methods[0..service.methods.len() - 1] {
468 buf.push_str("let mut instance = s.clone();\n");
469 generate_method_bind(&service.name, method, buf);
470 }
471
472 buf.push_str("let mut instance = s;\n");
473 generate_method_bind(
474 &service.name,
475 &service.methods[service.methods.len() - 1],
476 buf,
477 );
478
479 buf.push_str("builder.build()\n");
480 buf.push_str("}\n");
481 }
482
generate_server_methods(service: &Service, buf: &mut String)483 fn generate_server_methods(service: &Service, buf: &mut String) {
484 for method in &service.methods {
485 let method_type = MethodType::from_method(method);
486 let request_arg = match method_type {
487 MethodType::Unary | MethodType::ServerStreaming => {
488 format!("req: {}", method.input_type)
489 }
490 MethodType::ClientStreaming | MethodType::Duplex => format!(
491 "stream: {}<{}>",
492 fq_grpc("RequestStream"),
493 method.input_type
494 ),
495 };
496 let response_type = match method_type {
497 MethodType::Unary => "UnarySink",
498 MethodType::ClientStreaming => "ClientStreamingSink",
499 MethodType::ServerStreaming => "ServerStreamingSink",
500 MethodType::Duplex => "DuplexSink",
501 };
502 generate_server_method(method, &request_arg, response_type, buf);
503 }
504 }
505
generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, )506 fn generate_server_method(
507 method: &Method,
508 request_arg: &str,
509 response_type: &str,
510 buf: &mut String,
511 ) {
512 buf.push_str("fn ");
513 buf.push_str(&method.name);
514 buf.push_str("(&mut self, ctx: ");
515 buf.push_str(&fq_grpc("RpcContext"));
516 buf.push_str(", _");
517 buf.push_str(request_arg);
518 buf.push_str(", sink: ");
519 buf.push_str(&fq_grpc(response_type));
520 buf.push('<');
521 buf.push_str(&method.output_type);
522 buf.push('>');
523 buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n");
524 }
525
generate_method_bind(service_name: &str, method: &Method, buf: &mut String)526 fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) {
527 let add_name = match MethodType::from_method(method) {
528 MethodType::Unary => "add_unary_handler",
529 MethodType::ClientStreaming => "add_client_streaming_handler",
530 MethodType::ServerStreaming => "add_server_streaming_handler",
531 MethodType::Duplex => "add_duplex_streaming_handler",
532 };
533
534 buf.push_str("builder = builder.");
535 buf.push_str(add_name);
536 buf.push_str("(&");
537 buf.push_str(&const_method_name(service_name, method));
538 buf.push_str(", move |ctx, req, resp| instance.");
539 buf.push_str(&method.name);
540 buf.push_str("(ctx, req, resp));\n");
541 }
542
protoc_gen_grpc_rust_main()543 pub fn protoc_gen_grpc_rust_main() {
544 let mut args = env::args();
545 args.next();
546 let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default();
547 for arg in args {
548 if let Some(value) = arg.strip_prefix("--protos=") {
549 protos.extend(value.split(",").map(|s| s.to_string()));
550 } else if let Some(value) = arg.strip_prefix("--includes=") {
551 includes.extend(value.split(",").map(|s| s.to_string()));
552 } else if let Some(value) = arg.strip_prefix("--out-dir=") {
553 out_dir = value.to_string();
554 }
555 }
556 if protos.is_empty() {
557 panic!("should at least specify protos to generate");
558 }
559 compile_protos(&protos, &includes, &out_dir).unwrap();
560 }
561