1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2
3 use std::cell::UnsafeCell;
4 use std::collections::HashMap;
5 use std::ffi::CString;
6 use std::fmt::{self, Debug, Formatter};
7 use std::future::Future;
8 use std::pin::Pin;
9 use std::ptr;
10 use std::sync::atomic::{AtomicBool, Ordering};
11 use std::sync::{Arc, Mutex};
12 use std::task::{Context, Poll};
13
14 use crate::grpc_sys::{self, grpc_call_error, grpc_server};
15 use futures_util::ready;
16
17 use crate::call::server::*;
18 use crate::call::{MessageReader, Method, MethodType};
19 use crate::channel::ChannelArgs;
20 use crate::cq::CompletionQueue;
21 use crate::env::Environment;
22 use crate::error::{Error, Result};
23 use crate::task::{CallTag, CqFuture};
24 use crate::RpcStatus;
25 use crate::{RpcContext, ServerCredentials};
26
/// Default value for `ServerBuilder::requests_slot_per_cq`: how many call
/// requests `Server::start` posts to each completion queue (1024).
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1 << 10;
28
29 /// An RPC call holder.
30 #[derive(Clone)]
31 pub struct Handler<F> {
32 method_type: MethodType,
33 cb: F,
34 }
35
36 impl<F> Handler<F> {
new(method_type: MethodType, cb: F) -> Handler<F>37 pub fn new(method_type: MethodType, cb: F) -> Handler<F> {
38 Handler { method_type, cb }
39 }
40 }
41
42 pub trait CloneableHandler: Send {
handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>)43 fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>);
box_clone(&self) -> Box<dyn CloneableHandler>44 fn box_clone(&self) -> Box<dyn CloneableHandler>;
method_type(&self) -> MethodType45 fn method_type(&self) -> MethodType;
46 }
47
48 impl<F: 'static> CloneableHandler for Handler<F>
49 where
50 F: FnMut(RpcContext<'_>, Option<MessageReader>) + Send + Clone,
51 {
52 #[inline]
handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>)53 fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>) {
54 (self.cb)(ctx, reqs)
55 }
56
57 #[inline]
box_clone(&self) -> Box<dyn CloneableHandler>58 fn box_clone(&self) -> Box<dyn CloneableHandler> {
59 Box::new(self.clone())
60 }
61
62 #[inline]
method_type(&self) -> MethodType63 fn method_type(&self) -> MethodType {
64 self.method_type
65 }
66 }
67
68 /// [`Service`] factory in order to configure the properties.
69 ///
70 /// Use it to build a service which can be registered to a server.
71 pub struct ServiceBuilder {
72 handlers: HashMap<&'static [u8], BoxHandler>,
73 }
74
75 impl ServiceBuilder {
76 /// Initialize a new [`ServiceBuilder`].
new() -> ServiceBuilder77 pub fn new() -> ServiceBuilder {
78 ServiceBuilder {
79 handlers: HashMap::new(),
80 }
81 }
82
83 /// Add a unary RPC call handler.
add_unary_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, Req, UnarySink<Resp>) + Send + Clone + 'static,84 pub fn add_unary_handler<Req, Resp, F>(
85 mut self,
86 method: &Method<Req, Resp>,
87 mut handler: F,
88 ) -> ServiceBuilder
89 where
90 Req: 'static,
91 Resp: 'static,
92 F: FnMut(RpcContext<'_>, Req, UnarySink<Resp>) + Send + Clone + 'static,
93 {
94 let (ser, de) = (method.resp_ser(), method.req_de());
95 let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
96 execute_unary(ctx, ser, de, payload.unwrap(), &mut handler)
97 };
98 let ch = Box::new(Handler::new(MethodType::Unary, h));
99 self.handlers.insert(method.name.as_bytes(), ch);
100 self
101 }
102
103 /// Add a client streaming RPC call handler.
add_client_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, RequestStream<Req>, ClientStreamingSink<Resp>) + Send + Clone + 'static,104 pub fn add_client_streaming_handler<Req, Resp, F>(
105 mut self,
106 method: &Method<Req, Resp>,
107 mut handler: F,
108 ) -> ServiceBuilder
109 where
110 Req: 'static,
111 Resp: 'static,
112 F: FnMut(RpcContext<'_>, RequestStream<Req>, ClientStreamingSink<Resp>)
113 + Send
114 + Clone
115 + 'static,
116 {
117 let (ser, de) = (method.resp_ser(), method.req_de());
118 let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
119 execute_client_streaming(ctx, ser, de, &mut handler)
120 };
121 let ch = Box::new(Handler::new(MethodType::ClientStreaming, h));
122 self.handlers.insert(method.name.as_bytes(), ch);
123 self
124 }
125
126 /// Add a server streaming RPC call handler.
add_server_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, Req, ServerStreamingSink<Resp>) + Send + Clone + 'static,127 pub fn add_server_streaming_handler<Req, Resp, F>(
128 mut self,
129 method: &Method<Req, Resp>,
130 mut handler: F,
131 ) -> ServiceBuilder
132 where
133 Req: 'static,
134 Resp: 'static,
135 F: FnMut(RpcContext<'_>, Req, ServerStreamingSink<Resp>) + Send + Clone + 'static,
136 {
137 let (ser, de) = (method.resp_ser(), method.req_de());
138 let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
139 execute_server_streaming(ctx, ser, de, payload.unwrap(), &mut handler)
140 };
141 let ch = Box::new(Handler::new(MethodType::ServerStreaming, h));
142 self.handlers.insert(method.name.as_bytes(), ch);
143 self
144 }
145
146 /// Add a duplex streaming RPC call handler.
add_duplex_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, RequestStream<Req>, DuplexSink<Resp>) + Send + Clone + 'static,147 pub fn add_duplex_streaming_handler<Req, Resp, F>(
148 mut self,
149 method: &Method<Req, Resp>,
150 mut handler: F,
151 ) -> ServiceBuilder
152 where
153 Req: 'static,
154 Resp: 'static,
155 F: FnMut(RpcContext<'_>, RequestStream<Req>, DuplexSink<Resp>) + Send + Clone + 'static,
156 {
157 let (ser, de) = (method.resp_ser(), method.req_de());
158 let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
159 execute_duplex_streaming(ctx, ser, de, &mut handler)
160 };
161 let ch = Box::new(Handler::new(MethodType::Duplex, h));
162 self.handlers.insert(method.name.as_bytes(), ch);
163 self
164 }
165
166 /// Finalize the [`ServiceBuilder`] and build the [`Service`].
build(self) -> Service167 pub fn build(self) -> Service {
168 Service {
169 handlers: self.handlers,
170 }
171 }
172 }
173
174 /// Used to indicate the result of the check. If it returns `Abort`,
175 /// skip the subsequent checkers and abort the grpc call.
176 pub enum CheckResult {
177 Continue,
178 Abort(RpcStatus),
179 }
180
181 pub trait ServerChecker: Send {
check(&mut self, ctx: &RpcContext) -> CheckResult182 fn check(&mut self, ctx: &RpcContext) -> CheckResult;
box_clone(&self) -> Box<dyn ServerChecker>183 fn box_clone(&self) -> Box<dyn ServerChecker>;
184 }
185
186 impl Clone for Box<dyn ServerChecker> {
clone(&self) -> Self187 fn clone(&self) -> Self {
188 self.box_clone()
189 }
190 }
191
192 /// A gRPC service.
193 ///
194 /// Use [`ServiceBuilder`] to build a [`Service`].
195 pub struct Service {
196 handlers: HashMap<&'static [u8], BoxHandler>,
197 }
198
199 /// [`Server`] factory in order to configure the properties.
200 pub struct ServerBuilder {
201 env: Arc<Environment>,
202 args: Option<ChannelArgs>,
203 slots_per_cq: usize,
204 handlers: HashMap<&'static [u8], BoxHandler>,
205 checkers: Vec<Box<dyn ServerChecker>>,
206 }
207
208 impl ServerBuilder {
209 /// Initialize a new [`ServerBuilder`].
new(env: Arc<Environment>) -> ServerBuilder210 pub fn new(env: Arc<Environment>) -> ServerBuilder {
211 ServerBuilder {
212 env,
213 args: None,
214 slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
215 handlers: HashMap::new(),
216 checkers: Vec::new(),
217 }
218 }
219
220 /// Add additional configuration for each incoming channel.
channel_args(mut self, args: ChannelArgs) -> ServerBuilder221 pub fn channel_args(mut self, args: ChannelArgs) -> ServerBuilder {
222 self.args = Some(args);
223 self
224 }
225
226 /// Set how many requests a completion queue can handle.
requests_slot_per_cq(mut self, slots: usize) -> ServerBuilder227 pub fn requests_slot_per_cq(mut self, slots: usize) -> ServerBuilder {
228 self.slots_per_cq = slots;
229 self
230 }
231
232 /// Register a service.
register_service(mut self, service: Service) -> ServerBuilder233 pub fn register_service(mut self, service: Service) -> ServerBuilder {
234 self.handlers.extend(service.handlers);
235 self
236 }
237
238 /// Add a custom checker to handle some tasks before the grpc call handler starts.
239 /// This allows users to operate grpc call based on the context. Users can add
240 /// multiple checkers and they will be executed in the order added.
241 ///
242 /// TODO: Extend this interface to intercepte each payload like grpc-c++.
add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder243 pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
244 self.checkers.push(Box::new(checker));
245 self
246 }
247
248 /// Finalize the [`ServerBuilder`] and build the [`Server`].
build(self) -> Result<Server>249 pub fn build(self) -> Result<Server> {
250 let args = self
251 .args
252 .as_ref()
253 .map_or_else(ptr::null, ChannelArgs::as_ptr);
254 unsafe {
255 let server = grpc_sys::grpc_server_create(args, ptr::null_mut());
256 for cq in self.env.completion_queues() {
257 let cq_ref = cq.borrow()?;
258 grpc_sys::grpc_server_register_completion_queue(
259 server,
260 cq_ref.as_ptr(),
261 ptr::null_mut(),
262 );
263 }
264
265 Ok(Server {
266 env: self.env,
267 core: Arc::new(ServerCore {
268 server,
269 creds: Mutex::new(Vec::new()),
270 shutdown: AtomicBool::new(false),
271 slots_per_cq: self.slots_per_cq,
272 }),
273 handlers: self.handlers,
274 checkers: self.checkers,
275 })
276 }
277 }
278 }
279
280 struct ServerCore {
281 server: *mut grpc_server,
282 creds: Mutex<Vec<ServerCredentials>>,
283 slots_per_cq: usize,
284 shutdown: AtomicBool,
285 }
286
287 impl Drop for ServerCore {
drop(&mut self)288 fn drop(&mut self) {
289 unsafe { grpc_sys::grpc_server_destroy(self.server) }
290 }
291 }
292
293 unsafe impl Send for ServerCore {}
294 unsafe impl Sync for ServerCore {}
295
296 pub type BoxHandler = Box<dyn CloneableHandler>;
297
298 #[derive(Clone)]
299 pub struct RequestCallContext {
300 server: Arc<ServerCore>,
301 registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
302 checkers: Vec<Box<dyn ServerChecker>>,
303 }
304
305 impl RequestCallContext {
306 /// Users should guarantee the method is always called from the same thread.
307 /// TODO: Is there a better way?
308 #[inline]
get_handler(&mut self, path: &[u8]) -> Option<&mut BoxHandler>309 pub unsafe fn get_handler(&mut self, path: &[u8]) -> Option<&mut BoxHandler> {
310 let registry = &mut *self.registry.get();
311 registry.get_mut(path)
312 }
313
get_checker(&self) -> Vec<Box<dyn ServerChecker>>314 pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
315 self.checkers.clone()
316 }
317 }
318
319 // Apparently, its life time is guaranteed by the ref count, hence is safe to be sent
320 // to other thread. However it's not `Sync`, as `BoxHandler` is unnecessarily `Sync`.
321 #[allow(clippy::non_send_fields_in_send_ty)]
322 unsafe impl Send for RequestCallContext {}
323
324 /// Request notification of a new call.
request_call(ctx: RequestCallContext, cq: &CompletionQueue)325 pub fn request_call(ctx: RequestCallContext, cq: &CompletionQueue) {
326 if ctx.server.shutdown.load(Ordering::Relaxed) {
327 return;
328 }
329 let cq_ref = match cq.borrow() {
330 // Shutting down, skip.
331 Err(_) => return,
332 Ok(c) => c,
333 };
334 let server_ptr = ctx.server.server;
335 let prom = CallTag::request(ctx);
336 let request_ptr = prom.request_ctx().unwrap().as_ptr();
337 let prom_box = Box::new(prom);
338 let tag = Box::into_raw(prom_box);
339 let code = unsafe {
340 grpc_sys::grpcwrap_server_request_call(
341 server_ptr,
342 cq_ref.as_ptr(),
343 request_ptr,
344 tag as *mut _,
345 )
346 };
347 if code != grpc_call_error::GRPC_CALL_OK {
348 drop(Box::from(tag));
349 panic!("failed to request call: {:?}", code);
350 }
351 }
352
353 /// A `Future` that will resolve when shutdown completes.
354 pub struct ShutdownFuture {
355 /// `true` means the future finishes successfully.
356 cq_f: CqFuture<bool>,
357 }
358
359 impl Future for ShutdownFuture {
360 type Output = Result<()>;
361
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>362 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
363 match ready!(Pin::new(&mut self.cq_f).poll(cx)) {
364 Ok(true) => Poll::Ready(Ok(())),
365 Ok(false) => Poll::Ready(Err(Error::ShutdownFailed)),
366 Err(e) => unreachable!("action future should never resolve to error: {}", e),
367 }
368 }
369 }
370
371 /// A gRPC server.
372 ///
373 /// A single server can serve arbitrary number of services and can listen on more than one port.
374 ///
375 /// Use [`ServerBuilder`] to build a [`Server`].
376 pub struct Server {
377 env: Arc<Environment>,
378 core: Arc<ServerCore>,
379 handlers: HashMap<&'static [u8], BoxHandler>,
380 checkers: Vec<Box<dyn ServerChecker>>,
381 }
382
383 impl Server {
384 /// Shutdown the server asynchronously.
shutdown(&mut self) -> ShutdownFuture385 pub fn shutdown(&mut self) -> ShutdownFuture {
386 let (cq_f, prom) = CallTag::action_pair();
387 let prom_box = Box::new(prom);
388 let tag = Box::into_raw(prom_box);
389 unsafe {
390 // Since env still exists, no way can cq been shutdown.
391 let cq_ref = self.env.completion_queues()[0].borrow().unwrap();
392 grpc_sys::grpc_server_shutdown_and_notify(
393 self.core.server,
394 cq_ref.as_ptr(),
395 tag as *mut _,
396 )
397 }
398 self.core.shutdown.store(true, Ordering::SeqCst);
399 ShutdownFuture { cq_f }
400 }
401
402 /// Cancel all in-progress calls.
403 ///
404 /// Only usable after shutdown.
cancel_all_calls(&mut self)405 pub fn cancel_all_calls(&mut self) {
406 unsafe { grpc_sys::grpc_server_cancel_all_calls(self.core.server) }
407 }
408
409 /// Start the server.
start(&mut self)410 pub fn start(&mut self) {
411 unsafe {
412 grpc_sys::grpc_server_start(self.core.server);
413 for cq in self.env.completion_queues() {
414 // Handlers are Send and Clone, but not Sync. So we need to
415 // provide a replica for each completion queue.
416 let registry = self
417 .handlers
418 .iter()
419 .map(|(k, v)| (k.to_owned(), v.box_clone()))
420 .collect();
421 let rc = RequestCallContext {
422 server: self.core.clone(),
423 registry: Arc::new(UnsafeCell::new(registry)),
424 checkers: self.checkers.clone(),
425 };
426 for _ in 0..self.core.slots_per_cq {
427 request_call(rc.clone(), cq);
428 }
429 }
430 }
431 }
432
433 /// Try binding the server to the given `addr` endpoint (eg, localhost:1234,
434 /// 192.168.1.1:31416, [::1]:27182, etc.).
435 ///
436 /// It can be invoked multiple times. Should be used before starting the server.
437 ///
438 /// # Return
439 ///
440 /// The bound port is returned on success.
add_listening_port( &mut self, addr: impl Into<String>, mut creds: ServerCredentials, ) -> Result<u16>441 pub fn add_listening_port(
442 &mut self,
443 addr: impl Into<String>,
444 mut creds: ServerCredentials,
445 ) -> Result<u16> {
446 // There is no Null in UTF-8 string.
447 let addr = CString::new(addr.into()).unwrap();
448 let port = unsafe {
449 grpcio_sys::grpc_server_add_http2_port(
450 self.core.server,
451 addr.as_ptr() as _,
452 creds.as_mut_ptr(),
453 ) as u16
454 };
455 if port != 0 {
456 self.core.creds.lock().unwrap().push(creds);
457 Ok(port)
458 } else {
459 Err(Error::BindFail(addr))
460 }
461 }
462
463 /// Add an rpc channel for an established connection represented as a file
464 /// descriptor. Takes ownership of the file descriptor, closing it when
465 /// channel is closed.
466 ///
467 /// # Safety
468 ///
469 /// The file descriptor must correspond to a connected stream socket. After
470 /// this call, the socket must not be accessed (read / written / closed)
471 /// by other code.
472 #[cfg(unix)]
add_channel_from_fd(&mut self, fd: ::std::os::raw::c_int)473 pub unsafe fn add_channel_from_fd(&mut self, fd: ::std::os::raw::c_int) {
474 let mut creds = ServerCredentials::insecure();
475 grpcio_sys::grpc_server_add_channel_from_fd(self.core.server, fd, creds.as_mut_ptr())
476 }
477 }
478
479 impl Drop for Server {
drop(&mut self)480 fn drop(&mut self) {
481 // if the server is not shutdown completely, destroy a server will core.
482 // TODO: don't wait here
483 let f = if !self.core.shutdown.load(Ordering::SeqCst) {
484 Some(self.shutdown())
485 } else {
486 None
487 };
488 self.cancel_all_calls();
489 let _ = f.map(futures_executor::block_on);
490 }
491 }
492
493 impl Debug for Server {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result494 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
495 write!(
496 f,
497 "Server {{ handlers: {}, checkers: {} }}",
498 self.handlers.len(),
499 self.checkers.len()
500 )
501 }
502 }
503