// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
use std::io::{Error, ErrorKind, Read};
use std::path::Path;
use std::{env, fs, io, process::Command, str};
use derive_new::new;
use prost::Message;
use prost_build::{Config, Method, Service, ServiceGenerator};
use prost_types::FileDescriptorSet;
use crate::util::{fq_grpc, to_snake_case, MethodType};
/// Returns the names of all packages compiled.
pub fn compile_protos
(protos: &[P], includes: &[P], out_dir: &str) -> io::Result>
where
P: AsRef,
{
let mut prost_config = Config::new();
prost_config.service_generator(Box::new(Generator));
prost_config.out_dir(out_dir);
// Create a file descriptor set for the protocol files.
let tmp = tempfile::Builder::new().prefix("prost-build").tempdir()?;
std::fs::create_dir_all(tmp.path())?;
let descriptor_set = tmp.path().join("prost-descriptor-set");
let mut cmd = Command::new(prost_build::protoc_from_env());
cmd.arg("--include_imports")
.arg("--include_source_info")
.arg("-o")
.arg(&descriptor_set);
for include in includes {
cmd.arg("-I").arg(include.as_ref());
}
// Set the protoc include after the user includes in case the user wants to
// override one of the built-in .protos.
if let Some(inc) = prost_build::protoc_include_from_env() {
cmd.arg("-I").arg(inc);
}
for proto in protos {
cmd.arg(proto.as_ref());
}
let output = cmd.output()?;
if !output.status.success() {
return Err(Error::new(
ErrorKind::Other,
format!("protoc failed: {}", String::from_utf8_lossy(&output.stderr)),
));
}
let mut buf = Vec::new();
fs::File::open(descriptor_set)?.read_to_end(&mut buf)?;
let descriptor_set = FileDescriptorSet::decode(buf.as_slice())?;
// Get the package names from the descriptor set.
let mut packages: Vec<_> = descriptor_set
.file
.iter()
.filter_map(|f| f.package.clone())
.collect();
packages.sort();
packages.dedup();
// FIXME(https://github.com/danburkert/prost/pull/155)
// Unfortunately we have to forget the above work and use `compile_protos` to
// actually generate the Rust code.
prost_config.compile_protos(protos, includes)?;
Ok(packages)
}
struct Generator;
impl ServiceGenerator for Generator {
fn generate(&mut self, service: Service, buf: &mut String) {
generate_methods(&service, buf);
generate_client(&service, buf);
generate_server(&service, buf);
}
}
/// Emits one method constant per RPC of `service`.
///
/// The service path is `/<package>.<proto_name>`, or `/<proto_name>` when the
/// service has no package.
fn generate_methods(service: &Service, buf: &mut String) {
    let mut path = String::from("/");
    if !service.package.is_empty() {
        path.push_str(&service.package);
        path.push('.');
    }
    path.push_str(&service.proto_name);

    for rpc in service.methods.iter() {
        generate_method(&service.name, &path, rpc, buf);
    }
}
/// Builds the identifier of the constant describing `method`:
/// `METHOD_<SERVICE>_<METHOD>`, where the service part is the snake-cased
/// service name and both parts are upper-cased.
fn const_method_name(service_name: &str, method: &Method) -> String {
    let service_part = to_snake_case(service_name).to_uppercase();
    let method_part = method.name.to_uppercase();
    ["METHOD", service_part.as_str(), method_part.as_str()].join("_")
}
/// Emits the `const METHOD_…: Method<Req, Resp> = …;` item for one RPC.
///
/// The left-hand side is written here; the initializer expression (and the
/// terminating `;`) comes from `generate_method_body`.
fn generate_method(service_name: &str, service_path: &str, method: &Method, buf: &mut String) {
    let declaration = format!(
        "const {}: {}<{}, {}> = ",
        const_method_name(service_name, method),
        fq_grpc("Method"),
        method.input_type,
        method.output_type
    );
    buf.push_str(&declaration);
    generate_method_body(service_path, method, buf);
}
/// Emits the struct-literal initializer of a method constant, followed by
/// `;` and a newline.
///
/// The literal sets `ty` (the method type), `name` (the quoted
/// `<service_path>/<proto_name>` string) and the request/response
/// marshallers, which both use the same `pr_ser`/`pr_de` pair.
fn generate_method_body(service_path: &str, method: &Method, buf: &mut String) {
    let method_ty = fq_grpc(&MethodType::from_method(method).to_string());
    let marshaller = format!(
        "{} {{ ser: {}, de: {} }}",
        fq_grpc("Marshaller"),
        fq_grpc("pr_ser"),
        fq_grpc("pr_de")
    );
    let quoted_name = format!("\"{}/{}\"", service_path, method.proto_name);

    buf.push_str(&fq_grpc("Method"));
    buf.push('{');

    // Fields are written in exactly this order.
    let fields = [
        ("ty", method_ty.as_str()),
        ("name", quoted_name.as_str()),
        ("req_mar", marshaller.as_str()),
        ("resp_mar", marshaller.as_str()),
    ];
    for (field, value) in fields.iter() {
        generate_field_init(field, value, buf);
    }

    buf.push_str("};\n");
}
// TODO share this code with protobuf codegen
impl MethodType {
    /// Classifies `method` by which side(s) of the call stream messages.
    fn from_method(method: &Method) -> MethodType {
        if method.client_streaming {
            if method.server_streaming {
                MethodType::Duplex
            } else {
                MethodType::ClientStreaming
            }
        } else if method.server_streaming {
            MethodType::ServerStreaming
        } else {
            MethodType::Unary
        }
    }
}
/// Appends a single struct-literal field initializer, `name: value, `
/// (including the trailing comma and space), to `buf`.
fn generate_field_init(name: &str, value: &str, buf: &mut String) {
    for piece in [name, ": ", value, ", "].iter() {
        buf.push_str(piece);
    }
}
/// Emits the `<Service>Client` struct and its `impl` block: the constructor,
/// one group of call methods per RPC, and the `spawn` helper.
fn generate_client(service: &Service, buf: &mut String) {
    let client_name = format!("{}Client", service.name);

    // Struct definition plus the opening of the impl block.
    let header = format!(
        "#[derive(Clone)]\npub struct {name} {{ pub client: ::grpcio::Client }}\nimpl {name} {{\n",
        name = client_name
    );
    buf.push_str(&header);

    generate_ctor(&client_name, buf);
    generate_client_methods(service, buf);
    generate_spawn(buf);

    // Close the impl block.
    buf.push_str("}\n");
}
/// Emits the client's `new(channel)` constructor, which wraps the channel in
/// a `::grpcio::Client`.
fn generate_ctor(client_name: &str, buf: &mut String) {
    let ctor = format!(
        "pub fn new(channel: ::grpcio::Channel) -> Self {{ {} {{ client: ::grpcio::Client::new(channel) }}}}\n",
        client_name
    );
    buf.push_str(&ctor);
}
/// Emits the client call methods for every RPC of `service`, in declaration
/// order.
fn generate_client_methods(service: &Service, buf: &mut String) {
    service
        .methods
        .iter()
        .for_each(|rpc| generate_client_method(&service.name, rpc, buf));
}
fn generate_client_method(service_name: &str, method: &Method, buf: &mut String) {
let name = &format!(
"METHOD_{}_{}",
to_snake_case(service_name).to_uppercase(),
method.name.to_uppercase()
);
match MethodType::from_method(method) {
MethodType::Unary => {
ClientMethod::new(
&method.name,
true,
Some(&method.input_type),
false,
vec![&method.output_type],
"unary_call",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
false,
Some(&method.input_type),
false,
vec![&method.output_type],
"unary_call",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
true,
Some(&method.input_type),
true,
vec![&format!(
"{}<{}>",
fq_grpc("ClientUnaryReceiver"),
method.output_type
)],
"unary_call",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
false,
Some(&method.input_type),
true,
vec![&format!(
"{}<{}>",
fq_grpc("ClientUnaryReceiver"),
method.output_type
)],
"unary_call",
name,
)
.generate(buf);
}
MethodType::ClientStreaming => {
ClientMethod::new(
&method.name,
true,
None,
false,
vec![
&format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
&format!(
"{}<{}>",
fq_grpc("ClientCStreamReceiver"),
method.output_type
),
],
"client_streaming",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
false,
None,
false,
vec![
&format!("{}<{}>", fq_grpc("ClientCStreamSender"), method.input_type),
&format!(
"{}<{}>",
fq_grpc("ClientCStreamReceiver"),
method.output_type
),
],
"client_streaming",
name,
)
.generate(buf);
}
MethodType::ServerStreaming => {
ClientMethod::new(
&method.name,
true,
Some(&method.input_type),
false,
vec![&format!(
"{}<{}>",
fq_grpc("ClientSStreamReceiver"),
method.output_type
)],
"server_streaming",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
false,
Some(&method.input_type),
false,
vec![&format!(
"{}<{}>",
fq_grpc("ClientSStreamReceiver"),
method.output_type
)],
"server_streaming",
name,
)
.generate(buf);
}
MethodType::Duplex => {
ClientMethod::new(
&method.name,
true,
None,
false,
vec![
&format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
&format!(
"{}<{}>",
fq_grpc("ClientDuplexReceiver"),
method.output_type
),
],
"duplex_streaming",
name,
)
.generate(buf);
ClientMethod::new(
&method.name,
false,
None,
false,
vec![
&format!("{}<{}>", fq_grpc("ClientDuplexSender"), method.input_type),
&format!(
"{}<{}>",
fq_grpc("ClientDuplexReceiver"),
method.output_type
),
],
"duplex_streaming",
name,
)
.generate(buf);
}
}
}
/// Description of one generated client function; `generate` turns it into
/// Rust source text. The `new` constructor comes from `#[derive(new)]`.
#[derive(new)]
struct ClientMethod<'a> {
/// Base name of the generated function, before any `_async`/`_opt` suffix.
method_name: &'a str,
/// When true the function gets an `_opt` suffix and an `opt` parameter and
/// calls the inner client directly; when false it delegates to the `_opt`
/// variant with a default call option.
opt: bool,
/// Type of the `req` parameter (taken by reference), or `None` when the
/// function takes no request.
request: Option<&'a str>,
/// When true, `_async` is appended to both the generated function name and
/// the inner client method name.
r#async: bool,
/// Types placed inside the returned `Result`; they are wrapped in a tuple
/// unless there is exactly one.
result_types: Vec<&'a str>,
/// Name of the method invoked on `self.client`.
inner_method_name: &'a str,
/// Expression passed by reference as the first argument of the inner call.
data_name: &'a str,
}
impl<'a> ClientMethod<'a> {
    /// Suffix appended to function names for the asynchronous flavour.
    fn async_suffix(&self) -> &'static str {
        if self.r#async {
            "_async"
        } else {
            ""
        }
    }

    /// Appends the complete `pub fn …(…) -> Result<…> { … }` item to `buf`.
    fn generate(&self, buf: &mut String) {
        let opt_suffix = if self.opt { "_opt" } else { "" };
        let req_param = self
            .request
            .map(|req| format!(", req: &{}", req))
            .unwrap_or_default();
        let opt_param = if self.opt {
            format!(", opt: {}", fq_grpc("CallOption"))
        } else {
            String::new()
        };

        // Every result type is followed by a comma; anything other than a
        // single type is wrapped in a tuple.
        let listed: String = self
            .result_types
            .iter()
            .map(|ty| format!("{},", ty))
            .collect();
        let returned = if self.result_types.len() == 1 {
            listed
        } else {
            format!("({})", listed)
        };

        buf.push_str(&format!(
            "pub fn {}{}{}(&self{}{}) -> {}<{}> {{ ",
            self.method_name,
            self.async_suffix(),
            opt_suffix,
            req_param,
            opt_param,
            fq_grpc("Result"),
            returned
        ));

        if self.opt {
            self.generate_inner_body(buf);
        } else {
            self.generate_opt_body(buf);
        }

        buf.push_str(" }\n");
    }

    // Method delegates to the `_opt` version of the method.
    fn generate_opt_body(&self, buf: &mut String) {
        let forwarded_req = if self.request.is_some() { "req, " } else { "" };
        buf.push_str(&format!(
            "self.{}{}_opt({}{})",
            self.method_name,
            self.async_suffix(),
            forwarded_req,
            fq_grpc("CallOption::default()")
        ));
    }

    // Method delegates to the inner client.
    fn generate_inner_body(&self, buf: &mut String) {
        let forwarded_req = if self.request.is_some() { ", req" } else { "" };
        buf.push_str(&format!(
            "self.client.{}{}(&{}{}, opt)",
            self.inner_method_name,
            self.async_suffix(),
            self.data_name,
            forwarded_req
        ));
    }
}
fn generate_spawn(buf: &mut String) {
buf.push_str(
"pub fn spawn(&self, f: F) \
where F: ::std::future::Future