// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. use std::io::{Error, ErrorKind, Read}; use std::path::Path; use std::{env, fs, io, process::Command, str}; use derive_new::new; use prost::Message; use prost_build::{Config, Method, Service, ServiceGenerator}; use prost_types::FileDescriptorSet; use crate::util::{fq_grpc, to_snake_case, MethodType}; /// Returns the names of all packages compiled. pub fn compile_protos

(protos: &[P], includes: &[P], out_dir: &str) -> io::Result> where P: AsRef, { let mut prost_config = Config::new(); prost_config.service_generator(Box::new(Generator)); prost_config.out_dir(out_dir); // Create a file descriptor set for the protocol files. let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?; std::fs::create_dir_all(tmp.path())?; let descriptor_set = tmp.path().join("prost-descriptor-set"); let mut cmd = Command::new(prost_build::protoc_from_env()); cmd.arg("--include_imports") .arg("--include_source_info") .arg("-o") .arg(&descriptor_set); for include in includes { cmd.arg("-I").arg(include.as_ref()); } // Set the protoc include after the user includes in case the user wants to // override one of the built-in .protos. if let Some(inc) = prost_build::protoc_include_from_env() { cmd.arg("-I").arg(inc); } for proto in protos { cmd.arg(proto.as_ref()); } let output = cmd.output()?; if !output.status.success() { return Err(Error::new( ErrorKind::Other, format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)), )); } let mut buf = Vec::new(); fs::File::open(descriptor_set)?.read_to_end(&mut buf)?; let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?; // Get the package names from the descriptor set. let mut packages: Vec<_> = descriptor_set .file .iter() .filter_map(|f| f.package.clone()) .collect(); packages.sort(); packages.dedup(); // FIXME(https://github.com/danburkert/prost/pull/155) // Unfortunately we have to forget the above work and use `compile_protos` to // actually generate the Rust code. prost_config.compile_protos(protos, includes)?; Ok(packages) } struct Generator; impl ServiceGenerator for Generator { fn generate(&mut self, service: Service, buf: &mut String) { generate_methods(&service, buf); generate_client(&service, buf); generate_server(&service, buf); } } fn generate_methods(service: &Service, buf: &mut String) { let service_path = if service.package.is_empty() { format!("/{}", service.proto_name) } else { format!("/{}.{}", service.package, service.proto_name) }; for method in &service.methods { generate_method(&service.name, &service_path, method, buf); } } fn const_method_name(service_name: &str, method: &Method) -> String { format!( "METHOD_{}_{}", to_snake_case(service_name).to_uppercase(), method.name.to_uppercase() ) } fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) { let name = const_method_name(service_name, method); let ty = format!( "{}<{}, {}>", fq_grpc("Method"), method.input_type, method.output_type ); buf.push_str("const "); buf.push_str(&name); buf.push_str(": "); buf.push_str(&ty); buf.push_str(" = "); generate_method_body(service_path, method, buf); } fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) { let ty = fq_grpc(&MethodType::from_method(method).to_string()); let pr_mar = format!( "{} {{ ser: {}, de: {} }}", fq_grpc("Marshaller"), fq_grpc("pr_ser"), fq_grpc("pr_de") ); buf.push_str(&fq_grpc("Method")); buf.push('{'); generate_field_init("ty", &ty, buf); generate_field_init( "name", &format!("\"{}/{}\"", service_path, method.proto_name), buf, ); generate_field_init("req_mar", &pr_mar, buf); generate_field_init("resp_mar", &pr_mar, buf); buf.push_str("};\n"); } // TODO share this code with protobuf codegen impl MethodType { fn from_method(method: &Method) -> MethodType { match (method.client_streaming, method.server_streaming) { (false, false) => MethodType::Unary, (true, false) => MethodType::ClientStreaming, (false, true) => MethodType::ServerStreaming, (true, true) => MethodType::Duplex, } } } fn generate_field_init(name: &str, value: &str, buf: &mut String) { buf.push_str(name); buf.push_str(": "); buf.push_str(value); buf.push_str(", "); } fn generate_client(service: &Service, buf: &mut String) { let client_name = format!("{}Client", service.name); buf.push_str("#[derive(Clone)]\n"); buf.push_str("pub struct "); buf.push_str(&client_name); buf.push_str(" { pub client: ::grpcio::Client }\n"); buf.push_str("impl "); buf.push_str(&client_name); buf.push_str(" {\n"); generate_ctor(&client_name, buf); generate_client_methods(service, buf); generate_spawn(buf); buf.push_str("}\n") } fn generate_ctor(client_name: &str, buf: &mut String) { buf.push_str("pub fn new(channel: ::grpcio::Channel) -> Self { "); buf.push_str(client_name); buf.push_str(" { client: ::grpcio::Client::new(channel) }"); buf.push_str("}\n"); } fn generate_client_methods(service: &Service, buf: &mut String) { for method in &service.methods { generate_client_method(&service.name, method, buf); } } fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) { let name = &format!( "METHOD_{}_{}", to_snake_case(service_name).to_uppercase(), method.name.to_uppercase() ); match MethodType::from_method(method) { MethodType::Unary => { ClientMethod::new( &method.name, true, Some(&method.input_type), false, vec![&method.output_type], "unary_call", name, ) .generate(buf); ClientMethod::new( &method.name, false, Some(&method.input_type), false, vec![&method.output_type], "unary_call", name, ) .generate(buf); ClientMethod::new( &method.name, true, Some(&method.input_type), true, vec![&format!( "{}<{}>", fq_grpc("ClientUnaryReceiver"), method.output_type )], "unary_call", name, ) .generate(buf); ClientMethod::new( &method.name, false, Some(&method.input_type), true, vec![&format!( "{}<{}>", fq_grpc("ClientUnaryReceiver"), method.output_type )], "unary_call", name, ) .generate(buf); } MethodType::ClientStreaming => { ClientMethod::new( &method.name, true, None, false, vec![ &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type), &format!( "{}<{}>", fq_grpc("ClientCStreamReceiver"), method.output_type ), ], "client_streaming", name, ) .generate(buf); ClientMethod::new( &method.name, false, None, false, vec![ &format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type), &format!( "{}<{}>", fq_grpc("ClientCStreamReceiver"), method.output_type ), ], "client_streaming", name, ) .generate(buf); } MethodType::ServerStreaming => { ClientMethod::new( &method.name, true, Some(&method.input_type), false, vec![&format!( "{}<{}>", fq_grpc("ClientSStreamReceiver"), method.output_type )], "server_streaming", name, ) .generate(buf); ClientMethod::new( &method.name, false, Some(&method.input_type), false, vec![&format!( "{}<{}>", fq_grpc("ClientSStreamReceiver"), method.output_type )], "server_streaming", name, ) .generate(buf); } MethodType::Duplex => { ClientMethod::new( &method.name, true, None, false, vec![ &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type), &format!( "{}<{}>", fq_grpc("ClientDuplexReceiver"), method.output_type ), ], "duplex_streaming", name, ) .generate(buf); ClientMethod::new( &method.name, false, None, false, vec![ &format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type), &format!( "{}<{}>", fq_grpc("ClientDuplexReceiver"), method.output_type ), ], "duplex_streaming", name, ) .generate(buf); } } } #[derive(new)] struct ClientMethod<'a> { method_name: &'a str, opt: bool, request: Option<&'a str>, r#async: bool, result_types: Vec<&'a str>, inner_method_name: &'a str, data_name: &'a str, } impl<'a> ClientMethod<'a> { fn generate(&self, buf: &mut String) { buf.push_str("pub fn "); buf.push_str(self.method_name); if self.r#async { buf.push_str("_async"); } if self.opt { buf.push_str("_opt"); } buf.push_str("(&self"); if let Some(req) = self.request { buf.push_str(", req: &"); buf.push_str(req); } if self.opt { buf.push_str(", opt: "); buf.push_str(&fq_grpc("CallOption")); } buf.push_str(") -> "); buf.push_str(&fq_grpc("Result")); buf.push('<'); if self.result_types.len() != 1 { buf.push('('); } for rt in &self.result_types { buf.push_str(rt); buf.push(','); } if self.result_types.len() != 1 { buf.push(')'); } buf.push_str("> { "); if self.opt { self.generate_inner_body(buf); } else { self.generate_opt_body(buf); } buf.push_str(" }\n"); } // Method delegates to the `_opt` version of the method. fn generate_opt_body(&self, buf: &mut String) { buf.push_str("self."); buf.push_str(self.method_name); if self.r#async { buf.push_str("_async"); } buf.push_str("_opt("); if self.request.is_some() { buf.push_str("req, "); } buf.push_str(&fq_grpc("CallOption::default()")); buf.push(')'); } // Method delegates to the inner client. fn generate_inner_body(&self, buf: &mut String) { buf.push_str("self.client."); buf.push_str(self.inner_method_name); if self.r#async { buf.push_str("_async"); } buf.push_str("(&"); buf.push_str(self.data_name); if self.request.is_some() { buf.push_str(", req"); } buf.push_str(", opt)"); } } fn generate_spawn(buf: &mut String) { buf.push_str( "pub fn spawn(&self, f: F) \ where F: ::std::future::Future + Send + 'static {\ self.client.spawn(f)\ }\n", ); } fn generate_server(service: &Service, buf: &mut String) { buf.push_str("pub trait "); buf.push_str(&service.name); buf.push_str(" {\n"); generate_server_methods(service, buf); buf.push_str("}\n"); buf.push_str("pub fn create_"); buf.push_str(&to_snake_case(&service.name)); buf.push_str("(s: S) -> "); buf.push_str(&fq_grpc("Service")); buf.push_str(" {\n"); buf.push_str("let mut builder = ::grpcio::ServiceBuilder::new();\n"); for method in &service.methods[0..service.methods.len() - 1] { buf.push_str("let mut instance = s.clone();\n"); generate_method_bind(&service.name, method, buf); } buf.push_str("let mut instance = s;\n"); generate_method_bind( &service.name, &service.methods[service.methods.len() - 1], buf, ); buf.push_str("builder.build()\n"); buf.push_str("}\n"); } fn generate_server_methods(service: &Service, buf: &mut String) { for method in &service.methods { let method_type = MethodType::from_method(method); let request_arg = match method_type { MethodType::Unary | MethodType::ServerStreaming => { format!("req: {}", method.input_type) } MethodType::ClientStreaming | MethodType::Duplex => format!( "stream: {}<{}>", fq_grpc("RequestStream"), method.input_type ), }; let response_type = match method_type { MethodType::Unary => "UnarySink", MethodType::ClientStreaming => "ClientStreamingSink", MethodType::ServerStreaming => "ServerStreamingSink", MethodType::Duplex => "DuplexSink", }; generate_server_method(method, &request_arg, response_type, buf); } } fn generate_server_method( method: &Method, request_arg: &str, response_type: &str, buf: &mut String, ) { buf.push_str("fn "); buf.push_str(&method.name); buf.push_str("(&mut self, ctx: "); buf.push_str(&fq_grpc("RpcContext")); buf.push_str(", _"); buf.push_str(request_arg); buf.push_str(", sink: "); buf.push_str(&fq_grpc(response_type)); buf.push('<'); buf.push_str(&method.output_type); buf.push('>'); buf.push_str(") { grpcio::unimplemented_call!(ctx, sink) }\n"); } fn generate_method_bind(service_name: &str, method: &Method, buf: &mut String) { let add_name = match MethodType::from_method(method) { MethodType::Unary => "add_unary_handler", MethodType::ClientStreaming => "add_client_streaming_handler", MethodType::ServerStreaming => "add_server_streaming_handler", MethodType::Duplex => "add_duplex_streaming_handler", }; buf.push_str("builder = builder."); buf.push_str(add_name); buf.push_str("(&"); buf.push_str(&const_method_name(service_name, method)); buf.push_str(", move |ctx, req, resp| instance."); buf.push_str(&method.name); buf.push_str("(ctx, req, resp));\n"); } pub fn protoc_gen_grpc_rust_main() { let mut args = env::args(); args.next(); let (mut protos, mut includes, mut out_dir): (Vec<_>, Vec<_>, _) = Default::default(); for arg in args { if let Some(value) = arg.strip_prefix("--protos=") { protos.extend(value.split(",").map(|s| s.to_string())); } else if let Some(value) = arg.strip_prefix("--includes=") { includes.extend(value.split(",").map(|s| s.to_string())); } else if let Some(value) = arg.strip_prefix("--out-dir=") { out_dir = value.to_string(); } } if protos.is_empty() { panic!("should at least specify protos to generate"); } compile_protos(&protos, &includes, &out_dir).unwrap(); }