1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2 
3 use std::cell::UnsafeCell;
4 use std::collections::HashMap;
5 use std::ffi::CString;
6 use std::fmt::{self, Debug, Formatter};
7 use std::future::Future;
8 use std::pin::Pin;
9 use std::ptr;
10 use std::sync::atomic::{AtomicBool, Ordering};
11 use std::sync::{Arc, Mutex};
12 use std::task::{Context, Poll};
13 
14 use crate::grpc_sys::{self, grpc_call_error, grpc_server};
15 use futures_util::ready;
16 
17 use crate::call::server::*;
18 use crate::call::{MessageReader, Method, MethodType};
19 use crate::channel::ChannelArgs;
20 use crate::cq::CompletionQueue;
21 use crate::env::Environment;
22 use crate::error::{Error, Result};
23 use crate::task::{CallTag, CqFuture};
24 use crate::RpcStatus;
25 use crate::{RpcContext, ServerCredentials};
26 
/// Default number of `request_call` slots posted to each completion queue when
/// the server starts; override it with [`ServerBuilder::requests_slot_per_cq`].
const DEFAULT_REQUEST_SLOTS_PER_CQ: usize = 1024;
28 
/// An RPC call holder.
///
/// Pairs a user callback `cb` with the [`MethodType`] of the RPC it serves.
#[derive(Clone)]
pub struct Handler<F> {
    /// Kind of RPC (unary, client streaming, server streaming or duplex) this handler serves.
    method_type: MethodType,
    /// Callback invoked for every incoming call routed to this handler.
    cb: F,
}
35 
36 impl<F> Handler<F> {
new(method_type: MethodType, cb: F) -> Handler<F>37     pub fn new(method_type: MethodType, cb: F) -> Handler<F> {
38         Handler { method_type, cb }
39     }
40 }
41 
/// Type-erased, cloneable RPC handler stored in the server's method registry.
///
/// `box_clone` exists because `Clone` is not object-safe; `Server::start` uses
/// it to give every completion queue its own copy of each handler.
pub trait CloneableHandler: Send {
    /// Handles one incoming call. `reqs` is the request payload: the unary and
    /// server-streaming handlers built by [`ServiceBuilder`] unwrap it, while the
    /// client-streaming and duplex ones ignore it.
    fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>);
    /// Returns a heap-allocated clone of this handler.
    fn box_clone(&self) -> Box<dyn CloneableHandler>;
    /// Returns the kind of RPC this handler serves.
    fn method_type(&self) -> MethodType;
}
47 
48 impl<F: 'static> CloneableHandler for Handler<F>
49 where
50     F: FnMut(RpcContext<'_>, Option<MessageReader>) + Send + Clone,
51 {
52     #[inline]
handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>)53     fn handle(&mut self, ctx: RpcContext<'_>, reqs: Option<MessageReader>) {
54         (self.cb)(ctx, reqs)
55     }
56 
57     #[inline]
box_clone(&self) -> Box<dyn CloneableHandler>58     fn box_clone(&self) -> Box<dyn CloneableHandler> {
59         Box::new(self.clone())
60     }
61 
62     #[inline]
method_type(&self) -> MethodType63     fn method_type(&self) -> MethodType {
64         self.method_type
65     }
66 }
67 
/// Factory used to configure and build a [`Service`].
///
/// Use it to build a service which can be registered to a server.
pub struct ServiceBuilder {
    /// Registered handlers keyed by the bytes of `Method::name`; registering the
    /// same name twice keeps only the latest handler.
    handlers: HashMap<&'static [u8], BoxHandler>,
}
74 
75 impl ServiceBuilder {
76     /// Initialize a new [`ServiceBuilder`].
new() -> ServiceBuilder77     pub fn new() -> ServiceBuilder {
78         ServiceBuilder {
79             handlers: HashMap::new(),
80         }
81     }
82 
83     /// Add a unary RPC call handler.
add_unary_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, Req, UnarySink<Resp>) + Send + Clone + 'static,84     pub fn add_unary_handler<Req, Resp, F>(
85         mut self,
86         method: &Method<Req, Resp>,
87         mut handler: F,
88     ) -> ServiceBuilder
89     where
90         Req: 'static,
91         Resp: 'static,
92         F: FnMut(RpcContext<'_>, Req, UnarySink<Resp>) + Send + Clone + 'static,
93     {
94         let (ser, de) = (method.resp_ser(), method.req_de());
95         let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
96             execute_unary(ctx, ser, de, payload.unwrap(), &mut handler)
97         };
98         let ch = Box::new(Handler::new(MethodType::Unary, h));
99         self.handlers.insert(method.name.as_bytes(), ch);
100         self
101     }
102 
103     /// Add a client streaming RPC call handler.
add_client_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, RequestStream<Req>, ClientStreamingSink<Resp>) + Send + Clone + 'static,104     pub fn add_client_streaming_handler<Req, Resp, F>(
105         mut self,
106         method: &Method<Req, Resp>,
107         mut handler: F,
108     ) -> ServiceBuilder
109     where
110         Req: 'static,
111         Resp: 'static,
112         F: FnMut(RpcContext<'_>, RequestStream<Req>, ClientStreamingSink<Resp>)
113             + Send
114             + Clone
115             + 'static,
116     {
117         let (ser, de) = (method.resp_ser(), method.req_de());
118         let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
119             execute_client_streaming(ctx, ser, de, &mut handler)
120         };
121         let ch = Box::new(Handler::new(MethodType::ClientStreaming, h));
122         self.handlers.insert(method.name.as_bytes(), ch);
123         self
124     }
125 
126     /// Add a server streaming RPC call handler.
add_server_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, Req, ServerStreamingSink<Resp>) + Send + Clone + 'static,127     pub fn add_server_streaming_handler<Req, Resp, F>(
128         mut self,
129         method: &Method<Req, Resp>,
130         mut handler: F,
131     ) -> ServiceBuilder
132     where
133         Req: 'static,
134         Resp: 'static,
135         F: FnMut(RpcContext<'_>, Req, ServerStreamingSink<Resp>) + Send + Clone + 'static,
136     {
137         let (ser, de) = (method.resp_ser(), method.req_de());
138         let h = move |ctx: RpcContext<'_>, payload: Option<MessageReader>| {
139             execute_server_streaming(ctx, ser, de, payload.unwrap(), &mut handler)
140         };
141         let ch = Box::new(Handler::new(MethodType::ServerStreaming, h));
142         self.handlers.insert(method.name.as_bytes(), ch);
143         self
144     }
145 
146     /// Add a duplex streaming RPC call handler.
add_duplex_streaming_handler<Req, Resp, F>( mut self, method: &Method<Req, Resp>, mut handler: F, ) -> ServiceBuilder where Req: 'static, Resp: 'static, F: FnMut(RpcContext<'_>, RequestStream<Req>, DuplexSink<Resp>) + Send + Clone + 'static,147     pub fn add_duplex_streaming_handler<Req, Resp, F>(
148         mut self,
149         method: &Method<Req, Resp>,
150         mut handler: F,
151     ) -> ServiceBuilder
152     where
153         Req: 'static,
154         Resp: 'static,
155         F: FnMut(RpcContext<'_>, RequestStream<Req>, DuplexSink<Resp>) + Send + Clone + 'static,
156     {
157         let (ser, de) = (method.resp_ser(), method.req_de());
158         let h = move |ctx: RpcContext<'_>, _: Option<MessageReader>| {
159             execute_duplex_streaming(ctx, ser, de, &mut handler)
160         };
161         let ch = Box::new(Handler::new(MethodType::Duplex, h));
162         self.handlers.insert(method.name.as_bytes(), ch);
163         self
164     }
165 
166     /// Finalize the [`ServiceBuilder`] and build the [`Service`].
build(self) -> Service167     pub fn build(self) -> Service {
168         Service {
169             handlers: self.handlers,
170         }
171     }
172 }
173 
/// Used to indicate the result of a [`ServerChecker::check`] call.
///
/// `Continue` lets the next checker (and finally the call handler) run. If a
/// checker returns `Abort`, the subsequent checkers are skipped and the gRPC
/// call is aborted.
pub enum CheckResult {
    /// Proceed with the next checker.
    Continue,
    /// Stop processing and abort the call; carries the status for the aborted call.
    Abort(RpcStatus),
}
180 
/// A hook run before a call's handler starts; see [`ServerBuilder::add_checker`].
pub trait ServerChecker: Send {
    /// Inspects the call context and decides whether processing continues.
    fn check(&mut self, ctx: &RpcContext) -> CheckResult;
    /// Returns a heap-allocated clone; backs `Clone for Box<dyn ServerChecker>`.
    fn box_clone(&self) -> Box<dyn ServerChecker>;
}
185 
186 impl Clone for Box<dyn ServerChecker> {
clone(&self) -> Self187     fn clone(&self) -> Self {
188         self.box_clone()
189     }
190 }
191 
/// A gRPC service.
///
/// Use [`ServiceBuilder`] to build a [`Service`].
pub struct Service {
    /// Method handlers keyed by the bytes of `Method::name`; merged into the
    /// server by [`ServerBuilder::register_service`].
    handlers: HashMap<&'static [u8], BoxHandler>,
}
198 
/// Factory used to configure and build a [`Server`].
pub struct ServerBuilder {
    /// Environment whose completion queues the server is registered with.
    env: Arc<Environment>,
    /// Optional channel arguments passed to `grpc_server_create`.
    args: Option<ChannelArgs>,
    /// Number of `request_call` slots posted per completion queue on start.
    slots_per_cq: usize,
    /// Handlers of all registered services, keyed by method name bytes.
    handlers: HashMap<&'static [u8], BoxHandler>,
    /// Checkers run, in the order added, before each call's handler.
    checkers: Vec<Box<dyn ServerChecker>>,
}
207 
208 impl ServerBuilder {
209     /// Initialize a new [`ServerBuilder`].
new(env: Arc<Environment>) -> ServerBuilder210     pub fn new(env: Arc<Environment>) -> ServerBuilder {
211         ServerBuilder {
212             env,
213             args: None,
214             slots_per_cq: DEFAULT_REQUEST_SLOTS_PER_CQ,
215             handlers: HashMap::new(),
216             checkers: Vec::new(),
217         }
218     }
219 
220     /// Add additional configuration for each incoming channel.
channel_args(mut self, args: ChannelArgs) -> ServerBuilder221     pub fn channel_args(mut self, args: ChannelArgs) -> ServerBuilder {
222         self.args = Some(args);
223         self
224     }
225 
226     /// Set how many requests a completion queue can handle.
requests_slot_per_cq(mut self, slots: usize) -> ServerBuilder227     pub fn requests_slot_per_cq(mut self, slots: usize) -> ServerBuilder {
228         self.slots_per_cq = slots;
229         self
230     }
231 
232     /// Register a service.
register_service(mut self, service: Service) -> ServerBuilder233     pub fn register_service(mut self, service: Service) -> ServerBuilder {
234         self.handlers.extend(service.handlers);
235         self
236     }
237 
238     /// Add a custom checker to handle some tasks before the grpc call handler starts.
239     /// This allows users to operate grpc call based on the context. Users can add
240     /// multiple checkers and they will be executed in the order added.
241     ///
242     /// TODO: Extend this interface to intercepte each payload like grpc-c++.
add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder243     pub fn add_checker<C: ServerChecker + 'static>(mut self, checker: C) -> ServerBuilder {
244         self.checkers.push(Box::new(checker));
245         self
246     }
247 
248     /// Finalize the [`ServerBuilder`] and build the [`Server`].
build(self) -> Result<Server>249     pub fn build(self) -> Result<Server> {
250         let args = self
251             .args
252             .as_ref()
253             .map_or_else(ptr::null, ChannelArgs::as_ptr);
254         unsafe {
255             let server = grpc_sys::grpc_server_create(args, ptr::null_mut());
256             for cq in self.env.completion_queues() {
257                 let cq_ref = cq.borrow()?;
258                 grpc_sys::grpc_server_register_completion_queue(
259                     server,
260                     cq_ref.as_ptr(),
261                     ptr::null_mut(),
262                 );
263             }
264 
265             Ok(Server {
266                 env: self.env,
267                 core: Arc::new(ServerCore {
268                     server,
269                     creds: Mutex::new(Vec::new()),
270                     shutdown: AtomicBool::new(false),
271                     slots_per_cq: self.slots_per_cq,
272                 }),
273                 handlers: self.handlers,
274                 checkers: self.checkers,
275             })
276         }
277     }
278 }
279 
/// Shared server state; owns the raw gRPC core server handle.
struct ServerCore {
    /// Raw gRPC core server, destroyed when this value is dropped.
    server: *mut grpc_server,
    /// Credentials of successfully bound ports; kept here so they live as long as the core.
    creds: Mutex<Vec<ServerCredentials>>,
    /// Number of `request_call` slots posted per completion queue on start.
    slots_per_cq: usize,
    /// Set once shutdown has been requested; `request_call` stops posting slots after that.
    shutdown: AtomicBool,
}
286 
287 impl Drop for ServerCore {
drop(&mut self)288     fn drop(&mut self) {
289         unsafe { grpc_sys::grpc_server_destroy(self.server) }
290     }
291 }
292 
// SAFETY: the only non-thread-safe field is the raw `server` pointer, which is
// owned by this struct and only handed to `grpc_server_*` functions; all other
// fields (`Mutex`, `AtomicBool`, `usize`) are themselves `Send + Sync`.
// NOTE(review): soundness relies on gRPC core's server API being callable from
// multiple threads concurrently — confirm against the gRPC core docs.
unsafe impl Send for ServerCore {}
unsafe impl Sync for ServerCore {}
295 
/// Heap-allocated, type-erased handler stored in the method registries.
pub type BoxHandler = Box<dyn CloneableHandler>;
297 
/// Per-completion-queue state handed to every `request_call`.
///
/// Cloning only bumps the reference counts of `server` and `registry`, but it
/// clones every checker through `box_clone`.
#[derive(Clone)]
pub struct RequestCallContext {
    /// Shared server core; keeps the raw server alive while calls are pending.
    server: Arc<ServerCore>,
    /// Handler registry shared by all clones made for one completion queue. It is
    /// an `UnsafeCell` so handlers can be borrowed mutably without a lock; see
    /// `get_handler` for the threading requirement.
    registry: Arc<UnsafeCell<HashMap<&'static [u8], BoxHandler>>>,
    /// Checkers run before the handler of each call.
    checkers: Vec<Box<dyn ServerChecker>>,
}
304 
305 impl RequestCallContext {
306     /// Users should guarantee the method is always called from the same thread.
307     /// TODO: Is there a better way?
308     #[inline]
get_handler(&mut self, path: &[u8]) -> Option<&mut BoxHandler>309     pub unsafe fn get_handler(&mut self, path: &[u8]) -> Option<&mut BoxHandler> {
310         let registry = &mut *self.registry.get();
311         registry.get_mut(path)
312     }
313 
get_checker(&self) -> Vec<Box<dyn ServerChecker>>314     pub(crate) fn get_checker(&self) -> Vec<Box<dyn ServerChecker>> {
315         self.checkers.clone()
316     }
317 }
318 
// Its lifetime is guaranteed by the reference counts, hence it is safe to be
// sent to another thread. However, it's not `Sync`, as `BoxHandler` is not
// necessarily `Sync`. The lint is allowed because the `UnsafeCell` registry
// field is intentionally not `Send`-checked by the compiler here.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for RequestCallContext {}
323 
324 /// Request notification of a new call.
request_call(ctx: RequestCallContext, cq: &CompletionQueue)325 pub fn request_call(ctx: RequestCallContext, cq: &CompletionQueue) {
326     if ctx.server.shutdown.load(Ordering::Relaxed) {
327         return;
328     }
329     let cq_ref = match cq.borrow() {
330         // Shutting down, skip.
331         Err(_) => return,
332         Ok(c) => c,
333     };
334     let server_ptr = ctx.server.server;
335     let prom = CallTag::request(ctx);
336     let request_ptr = prom.request_ctx().unwrap().as_ptr();
337     let prom_box = Box::new(prom);
338     let tag = Box::into_raw(prom_box);
339     let code = unsafe {
340         grpc_sys::grpcwrap_server_request_call(
341             server_ptr,
342             cq_ref.as_ptr(),
343             request_ptr,
344             tag as *mut _,
345         )
346     };
347     if code != grpc_call_error::GRPC_CALL_OK {
348         drop(Box::from(tag));
349         panic!("failed to request call: {:?}", code);
350     }
351 }
352 
/// A `Future` that will resolve when shutdown completes.
///
/// Resolves to `Ok(())` on success and to `Err(Error::ShutdownFailed)` otherwise.
pub struct ShutdownFuture {
    /// Completion of the shutdown tag; `true` means the future finishes successfully.
    cq_f: CqFuture<bool>,
}
358 
359 impl Future for ShutdownFuture {
360     type Output = Result<()>;
361 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output>362     fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
363         match ready!(Pin::new(&mut self.cq_f).poll(cx)) {
364             Ok(true) => Poll::Ready(Ok(())),
365             Ok(false) => Poll::Ready(Err(Error::ShutdownFailed)),
366             Err(e) => unreachable!("action future should never resolve to error: {}", e),
367         }
368     }
369 }
370 
/// A gRPC server.
///
/// A single server can serve an arbitrary number of services and can listen on more than one port.
///
/// Use [`ServerBuilder`] to build a [`Server`].
pub struct Server {
    /// Environment whose completion queues the server is registered with.
    env: Arc<Environment>,
    /// Shared server core; every `RequestCallContext` holds another reference to it.
    core: Arc<ServerCore>,
    /// Handler template; `start` gives each completion queue its own copy.
    handlers: HashMap<&'static [u8], BoxHandler>,
    /// Checkers cloned into each completion queue's context by `start`.
    checkers: Vec<Box<dyn ServerChecker>>,
}
382 
383 impl Server {
384     /// Shutdown the server asynchronously.
shutdown(&mut self) -> ShutdownFuture385     pub fn shutdown(&mut self) -> ShutdownFuture {
386         let (cq_f, prom) = CallTag::action_pair();
387         let prom_box = Box::new(prom);
388         let tag = Box::into_raw(prom_box);
389         unsafe {
390             // Since env still exists, no way can cq been shutdown.
391             let cq_ref = self.env.completion_queues()[0].borrow().unwrap();
392             grpc_sys::grpc_server_shutdown_and_notify(
393                 self.core.server,
394                 cq_ref.as_ptr(),
395                 tag as *mut _,
396             )
397         }
398         self.core.shutdown.store(true, Ordering::SeqCst);
399         ShutdownFuture { cq_f }
400     }
401 
402     /// Cancel all in-progress calls.
403     ///
404     /// Only usable after shutdown.
cancel_all_calls(&mut self)405     pub fn cancel_all_calls(&mut self) {
406         unsafe { grpc_sys::grpc_server_cancel_all_calls(self.core.server) }
407     }
408 
409     /// Start the server.
start(&mut self)410     pub fn start(&mut self) {
411         unsafe {
412             grpc_sys::grpc_server_start(self.core.server);
413             for cq in self.env.completion_queues() {
414                 // Handlers are Send and Clone, but not Sync. So we need to
415                 // provide a replica for each completion queue.
416                 let registry = self
417                     .handlers
418                     .iter()
419                     .map(|(k, v)| (k.to_owned(), v.box_clone()))
420                     .collect();
421                 let rc = RequestCallContext {
422                     server: self.core.clone(),
423                     registry: Arc::new(UnsafeCell::new(registry)),
424                     checkers: self.checkers.clone(),
425                 };
426                 for _ in 0..self.core.slots_per_cq {
427                     request_call(rc.clone(), cq);
428                 }
429             }
430         }
431     }
432 
433     /// Try binding the server to the given `addr` endpoint (eg, localhost:1234,
434     /// 192.168.1.1:31416, [::1]:27182, etc.).
435     ///
436     /// It can be invoked multiple times. Should be used before starting the server.
437     ///
438     /// # Return
439     ///
440     /// The bound port is returned on success.
add_listening_port( &mut self, addr: impl Into<String>, mut creds: ServerCredentials, ) -> Result<u16>441     pub fn add_listening_port(
442         &mut self,
443         addr: impl Into<String>,
444         mut creds: ServerCredentials,
445     ) -> Result<u16> {
446         // There is no Null in UTF-8 string.
447         let addr = CString::new(addr.into()).unwrap();
448         let port = unsafe {
449             grpcio_sys::grpc_server_add_http2_port(
450                 self.core.server,
451                 addr.as_ptr() as _,
452                 creds.as_mut_ptr(),
453             ) as u16
454         };
455         if port != 0 {
456             self.core.creds.lock().unwrap().push(creds);
457             Ok(port)
458         } else {
459             Err(Error::BindFail(addr))
460         }
461     }
462 
463     /// Add an rpc channel for an established connection represented as a file
464     /// descriptor. Takes ownership of the file descriptor, closing it when
465     /// channel is closed.
466     ///
467     /// # Safety
468     ///
469     /// The file descriptor must correspond to a connected stream socket. After
470     /// this call, the socket must not be accessed (read / written / closed)
471     /// by other code.
472     #[cfg(unix)]
add_channel_from_fd(&mut self, fd: ::std::os::raw::c_int)473     pub unsafe fn add_channel_from_fd(&mut self, fd: ::std::os::raw::c_int) {
474         let mut creds = ServerCredentials::insecure();
475         grpcio_sys::grpc_server_add_channel_from_fd(self.core.server, fd, creds.as_mut_ptr())
476     }
477 }
478 
479 impl Drop for Server {
drop(&mut self)480     fn drop(&mut self) {
481         // if the server is not shutdown completely, destroy a server will core.
482         // TODO: don't wait here
483         let f = if !self.core.shutdown.load(Ordering::SeqCst) {
484             Some(self.shutdown())
485         } else {
486             None
487         };
488         self.cancel_all_calls();
489         let _ = f.map(futures_executor::block_on);
490     }
491 }
492 
493 impl Debug for Server {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result494     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
495         write!(
496             f,
497             "Server {{ handlers: {}, checkers: {} }}",
498             self.handlers.len(),
499             self.checkers.len()
500         )
501     }
502 }
503