1 // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
2 
3 pub mod client;
4 pub mod server;
5 
6 use std::ffi::CStr;
7 use std::fmt::{self, Debug, Display};
8 use std::future::Future;
9 use std::pin::Pin;
10 use std::sync::Arc;
11 use std::task::{Context, Poll};
12 use std::{ptr, slice};
13 
14 use crate::grpc_sys::{self, grpc_call, grpc_call_error, grpcwrap_batch_context};
15 use crate::metadata::UnownedMetadata;
16 use crate::{cq::CompletionQueue, Metadata, MetadataBuilder};
17 use futures_util::ready;
18 use libc::c_void;
19 use parking_lot::Mutex;
20 
21 use crate::buf::{GrpcByteBuffer, GrpcByteBufferReader, GrpcSlice};
22 use crate::codec::{DeserializeFn, Marshaller, SerializeFn};
23 use crate::error::{Error, Result};
24 use crate::grpc_sys::grpc_status_code::*;
25 use crate::task::{self, BatchFuture, BatchResult, BatchType, CallTag};
26 
/// A gRPC status code.
///
/// Thin wrapper around the raw `i32` code used by gRPC core. The well-known codes are
/// available as associated constants such as `RpcStatusCode::OK`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct RpcStatusCode(i32);

impl From<i32> for RpcStatusCode {
    /// Wrap a raw status code as-is; no validation is performed.
    fn from(raw: i32) -> Self {
        Self(raw)
    }
}

impl From<RpcStatusCode> for i32 {
    /// Unwrap the raw status code.
    fn from(status: RpcStatusCode) -> Self {
        let RpcStatusCode(raw) = status;
        raw
    }
}
43 
44 impl Display for RpcStatusCode {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result45     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46         Debug::fmt(self, f)
47     }
48 }
49 
/// Generate the `RpcStatusCode` associated constants and its `Debug` impl from a table
/// of `(raw_code_path, CONST_NAME);` entries.
///
/// `Debug` prints `"<raw>-<CONST_NAME>"`, or `"<raw>-INVALID_STATUS_CODE"` when the raw
/// value matches none of the listed codes.
macro_rules! status_codes {
    (
        $(
            ($num:path, $konst:ident);
        )+
    ) => {
        impl RpcStatusCode {
            $(
                pub const $konst: RpcStatusCode = RpcStatusCode($num);
            )+
        }

        impl Debug for RpcStatusCode {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                let name = match self.0 {
                    $($num => stringify!($konst),)+
                    _ => "INVALID_STATUS_CODE",
                };
                write!(f, "{}-{}", self.0, name)
            }
        }
    }
}
77 
// Register every gRPC status code. Each entry pairs a `grpc_status_code` constant from
// `grpc_sys` with the name of the `RpcStatusCode` associated constant generated for it;
// that name is also what the generated `Debug` impl prints after the raw value.
status_codes! {
    (GRPC_STATUS_OK, OK);
    (GRPC_STATUS_CANCELLED, CANCELLED);
    (GRPC_STATUS_UNKNOWN, UNKNOWN);
    (GRPC_STATUS_INVALID_ARGUMENT, INVALID_ARGUMENT);
    (GRPC_STATUS_DEADLINE_EXCEEDED, DEADLINE_EXCEEDED);
    (GRPC_STATUS_NOT_FOUND, NOT_FOUND);
    (GRPC_STATUS_ALREADY_EXISTS, ALREADY_EXISTS);
    (GRPC_STATUS_PERMISSION_DENIED, PERMISSION_DENIED);
    (GRPC_STATUS_RESOURCE_EXHAUSTED, RESOURCE_EXHAUSTED);
    (GRPC_STATUS_FAILED_PRECONDITION, FAILED_PRECONDITION);
    (GRPC_STATUS_ABORTED, ABORTED);
    (GRPC_STATUS_OUT_OF_RANGE, OUT_OF_RANGE);
    (GRPC_STATUS_UNIMPLEMENTED, UNIMPLEMENTED);
    (GRPC_STATUS_INTERNAL, INTERNAL);
    (GRPC_STATUS_UNAVAILABLE, UNAVAILABLE);
    (GRPC_STATUS_DATA_LOSS, DATA_LOSS);
    (GRPC_STATUS_UNAUTHENTICATED, UNAUTHENTICATED);
    (GRPC_STATUS__DO_NOT_USE, DO_NOT_USE);
}
98 
/// The four call shapes supported by gRPC, classified by which side streams.
#[derive(Copy, Clone)]
pub enum MethodType {
    /// The client sends one request and the server replies with one response.
    Unary,

    /// The client streams requests and the server replies with one response.
    ClientStreaming,

    /// The client sends one request and the server streams responses.
    ServerStreaming,

    /// Client and server both stream any number of messages, concurrently.
    Duplex,
}
114 
115 /// A description of a remote method.
116 // TODO: add serializer and deserializer.
117 pub struct Method<Req, Resp> {
118     /// Type of method.
119     pub ty: MethodType,
120 
121     /// Full qualified name of the method.
122     pub name: &'static str,
123 
124     /// The marshaller used for request messages.
125     pub req_mar: Marshaller<Req>,
126 
127     /// The marshaller used for response messages.
128     pub resp_mar: Marshaller<Resp>,
129 }
130 
131 impl<Req, Resp> Method<Req, Resp> {
132     /// Get the request serializer.
133     #[inline]
req_ser(&self) -> SerializeFn<Req>134     pub fn req_ser(&self) -> SerializeFn<Req> {
135         self.req_mar.ser
136     }
137 
138     /// Get the request deserializer.
139     #[inline]
req_de(&self) -> DeserializeFn<Req>140     pub fn req_de(&self) -> DeserializeFn<Req> {
141         self.req_mar.de
142     }
143 
144     /// Get the response serializer.
145     #[inline]
resp_ser(&self) -> SerializeFn<Resp>146     pub fn resp_ser(&self) -> SerializeFn<Resp> {
147         self.resp_mar.ser
148     }
149 
150     /// Get the response deserializer.
151     #[inline]
resp_de(&self) -> DeserializeFn<Resp>152     pub fn resp_de(&self) -> DeserializeFn<Resp> {
153         self.resp_mar.de
154     }
155 }
156 
157 /// RPC result returned from the server.
158 #[derive(Debug, Clone)]
159 pub struct RpcStatus {
160     /// gRPC status code. `Ok` indicates success, all other values indicate an error.
161     code: RpcStatusCode,
162 
163     /// error message.
164     message: String,
165 
166     /// Additional details for rich error model.
167     ///
168     /// See also https://grpc.io/docs/guides/error/#richer-error-model.
169     details: Vec<u8>,
170 
171     /// Debug error string
172     debug_error_string: String,
173 }
174 
175 impl Display for RpcStatus {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result176     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
177         Debug::fmt(self, fmt)
178     }
179 }
180 
181 impl RpcStatus {
182     /// Create a new [`RpcStatus`].
new<T: Into<RpcStatusCode>>(code: T) -> RpcStatus183     pub fn new<T: Into<RpcStatusCode>>(code: T) -> RpcStatus {
184         RpcStatus::with_message(code, String::new())
185     }
186 
187     /// Create a new [`RpcStatus`] with given message.
with_message<T: Into<RpcStatusCode>>(code: T, message: String) -> RpcStatus188     pub fn with_message<T: Into<RpcStatusCode>>(code: T, message: String) -> RpcStatus {
189         RpcStatus::with_details(code, message, vec![])
190     }
191 
192     /// Create a new [`RpcStats`] with code, message and details.
193     ///
194     /// If using rich error model, `details` should be binary message that sets `code` and
195     /// `message` to the same value. Or you can use `into` method to do automatic
196     /// transformation if using `grpcio_proto::google::rpc::Status`.
with_details<T: Into<RpcStatusCode>>( code: T, message: String, details: Vec<u8>, ) -> RpcStatus197     pub fn with_details<T: Into<RpcStatusCode>>(
198         code: T,
199         message: String,
200         details: Vec<u8>,
201     ) -> RpcStatus {
202         RpcStatus::with_details_and_error_string(code, message, details, String::new())
203     }
204 
205     /// Create a new [`RpcStats`] with code, message, details and debug error string.
206     ///
207     /// If using rich error model, `details` should be binary message that sets `code` and
208     /// `message` to the same value. Or you can use `into` method to do automatic
209     /// transformation if using `grpcio_proto::google::rpc::Status`.
with_details_and_error_string<T: Into<RpcStatusCode>>( code: T, message: String, details: Vec<u8>, debug_error_string: String, ) -> RpcStatus210     pub fn with_details_and_error_string<T: Into<RpcStatusCode>>(
211         code: T,
212         message: String,
213         details: Vec<u8>,
214         debug_error_string: String,
215     ) -> RpcStatus {
216         RpcStatus {
217             code: code.into(),
218             message,
219             details,
220             debug_error_string,
221         }
222     }
223 
224     /// Create a new [`RpcStatus`] that status code is Ok.
ok() -> RpcStatus225     pub fn ok() -> RpcStatus {
226         RpcStatus::new(RpcStatusCode::OK)
227     }
228 
229     /// Return the instance's error code.
230     #[inline]
code(&self) -> RpcStatusCode231     pub fn code(&self) -> RpcStatusCode {
232         self.code
233     }
234 
235     /// Return the instance's error message.
236     #[inline]
message(&self) -> &str237     pub fn message(&self) -> &str {
238         &self.message
239     }
240 
241     /// Return the (binary) error details.
242     ///
243     /// Usually it contains a serialized `google.rpc.Status` proto.
details(&self) -> &[u8]244     pub fn details(&self) -> &[u8] {
245         &self.details
246     }
247 
248     /// Return the debug error string.
249     ///
250     /// This will return a detailed string of the gRPC Core error that led to the failure.
251     /// It shouldn't be relied upon for anything other than gaining more debug data in
252     /// failure cases.
debug_error_string(&self) -> &str253     pub fn debug_error_string(&self) -> &str {
254         &self.debug_error_string
255     }
256 }
257 
/// Reader over the bytes of a received message; this is the type returned by
/// `BatchContext::recv_message` (an alias of [`GrpcByteBufferReader`]).
pub type MessageReader = GrpcByteBufferReader;
259 
260 /// Context for batch request.
261 pub struct BatchContext {
262     ctx: *mut grpcwrap_batch_context,
263 }
264 
265 impl BatchContext {
new() -> BatchContext266     pub fn new() -> BatchContext {
267         BatchContext {
268             ctx: unsafe { grpc_sys::grpcwrap_batch_context_create() },
269         }
270     }
271 
as_ptr(&self) -> *mut grpcwrap_batch_context272     pub fn as_ptr(&self) -> *mut grpcwrap_batch_context {
273         self.ctx
274     }
275 
take_recv_message(&self) -> Option<GrpcByteBuffer>276     pub fn take_recv_message(&self) -> Option<GrpcByteBuffer> {
277         let ptr = unsafe { grpc_sys::grpcwrap_batch_context_take_recv_message(self.ctx) };
278         if ptr.is_null() {
279             None
280         } else {
281             Some(unsafe { GrpcByteBuffer::from_raw(ptr) })
282         }
283     }
284 
285     /// Get the status of the rpc call.
rpc_status(&self) -> RpcStatus286     pub fn rpc_status(&self) -> RpcStatus {
287         let status = RpcStatusCode(unsafe {
288             grpc_sys::grpcwrap_batch_context_recv_status_on_client_status(self.ctx)
289         });
290 
291         if status == RpcStatusCode::OK {
292             RpcStatus::ok()
293         } else {
294             unsafe {
295                 let mut msg_len = 0;
296                 let details_ptr = grpc_sys::grpcwrap_batch_context_recv_status_on_client_details(
297                     self.ctx,
298                     &mut msg_len,
299                 );
300                 let msg_slice = slice::from_raw_parts(details_ptr as *const _, msg_len);
301                 let message = String::from_utf8_lossy(msg_slice).into_owned();
302                 let m_ptr =
303                     grpc_sys::grpcwrap_batch_context_recv_status_on_client_trailing_metadata(
304                         self.ctx,
305                     );
306                 let metadata = &*(m_ptr as *const Metadata);
307                 let details = metadata.search_binary_error_details().to_vec();
308 
309                 let error_string_ptr =
310                     grpc_sys::grpcwrap_batch_context_recv_status_on_client_error_string(self.ctx);
311                 let error_string = if error_string_ptr.is_null() {
312                     String::new()
313                 } else {
314                     CStr::from_ptr(error_string_ptr)
315                         .to_string_lossy()
316                         .into_owned()
317                 };
318 
319                 RpcStatus::with_details_and_error_string(status, message, details, error_string)
320             }
321         }
322     }
323 
324     /// Fetch the response bytes of the rpc call.
recv_message(&mut self) -> Option<MessageReader>325     pub fn recv_message(&mut self) -> Option<MessageReader> {
326         let buf = self.take_recv_message()?;
327         Some(GrpcByteBufferReader::new(buf))
328     }
329 
330     /// Get the initial metadata from response.
331     ///
332     /// If initial metadata is not fetched or the method has been called, empty metadata will be
333     /// returned.
take_initial_metadata(&mut self) -> UnownedMetadata334     pub fn take_initial_metadata(&mut self) -> UnownedMetadata {
335         let mut res = UnownedMetadata::empty();
336         unsafe {
337             grpcio_sys::grpcwrap_batch_context_take_recv_initial_metadata(
338                 self.ctx,
339                 res.as_mut_ptr(),
340             );
341         }
342         res
343     }
344 
345     /// Get the trailing metadata from response.
346     ///
347     /// If trailing metadata is not fetched or the method has been called, empty metadata will be
348     /// returned.
take_trailing_metadata(&mut self) -> UnownedMetadata349     pub fn take_trailing_metadata(&mut self) -> UnownedMetadata {
350         let mut res = UnownedMetadata::empty();
351         unsafe {
352             grpc_sys::grpcwrap_batch_context_take_recv_status_on_client_trailing_metadata(
353                 self.ctx,
354                 res.as_mut_ptr(),
355             );
356         }
357         res
358     }
359 }
360 
361 impl Drop for BatchContext {
drop(&mut self)362     fn drop(&mut self) {
363         unsafe { grpc_sys::grpcwrap_batch_context_destroy(self.ctx) }
364     }
365 }
366 
367 #[inline]
box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut CallTag)368 fn box_batch_tag(tag: CallTag) -> (*mut grpcwrap_batch_context, *mut CallTag) {
369     let tag_box = Box::new(tag);
370     (
371         tag_box.batch_ctx().unwrap().as_ptr(),
372         Box::into_raw(tag_box),
373     )
374 }
375 
376 /// A helper function that runs the batch call and checks the result.
check_run<F>(bt: BatchType, f: F) -> BatchFuture where F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error,377 fn check_run<F>(bt: BatchType, f: F) -> BatchFuture
378 where
379     F: FnOnce(*mut grpcwrap_batch_context, *mut c_void) -> grpc_call_error,
380 {
381     let (cq_f, tag) = CallTag::batch_pair(bt);
382     let (batch_ptr, tag_ptr) = box_batch_tag(tag);
383     let code = f(batch_ptr, tag_ptr as *mut c_void);
384     if code != grpc_call_error::GRPC_CALL_OK {
385         unsafe {
386             drop(Box::from_raw(tag_ptr));
387         }
388         panic!("create call fail: {:?}", code);
389     }
390     cq_f
391 }
392 
393 /// A Call represents an RPC.
394 ///
395 /// When created, it is in a configuration state allowing properties to be
396 /// set until it is invoked. After invoke, the Call can have messages
397 /// written to it and read from it.
398 pub struct Call {
399     pub call: *mut grpc_call,
400     pub cq: CompletionQueue,
401 }
402 
403 unsafe impl Send for Call {}
404 
405 impl Call {
from_raw(call: *mut grpc_sys::grpc_call, cq: CompletionQueue) -> Call406     pub unsafe fn from_raw(call: *mut grpc_sys::grpc_call, cq: CompletionQueue) -> Call {
407         assert!(!call.is_null());
408         Call { call, cq }
409     }
410 
411     /// Send a message asynchronously.
start_send_message( &mut self, msg: &mut GrpcSlice, write_flags: u32, initial_metadata: Option<&mut Metadata>, call_flags: u32, ) -> Result<BatchFuture>412     pub fn start_send_message(
413         &mut self,
414         msg: &mut GrpcSlice,
415         write_flags: u32,
416         initial_metadata: Option<&mut Metadata>,
417         call_flags: u32,
418     ) -> Result<BatchFuture> {
419         let _cq_ref = self.cq.borrow()?;
420         let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
421             grpc_sys::grpcwrap_call_send_message(
422                 self.call,
423                 ctx,
424                 msg.as_mut_ptr(),
425                 write_flags,
426                 initial_metadata.map_or_else(ptr::null_mut, |m| m as *mut _ as _),
427                 call_flags,
428                 tag,
429             )
430         });
431         Ok(f)
432     }
433 
434     /// Finish the rpc call from client.
start_send_close_client(&mut self) -> Result<BatchFuture>435     pub fn start_send_close_client(&mut self) -> Result<BatchFuture> {
436         let _cq_ref = self.cq.borrow()?;
437         let f = check_run(BatchType::Finish, |_, tag| unsafe {
438             grpc_sys::grpcwrap_call_send_close_from_client(self.call, tag)
439         });
440         Ok(f)
441     }
442 
443     /// Receive a message asynchronously.
start_recv_message(&mut self) -> Result<BatchFuture>444     pub fn start_recv_message(&mut self) -> Result<BatchFuture> {
445         let _cq_ref = self.cq.borrow()?;
446         let f = check_run(BatchType::Read, |ctx, tag| unsafe {
447             grpc_sys::grpcwrap_call_recv_message(self.call, ctx, tag)
448         });
449         Ok(f)
450     }
451 
452     /// Start handling from server side.
453     ///
454     /// Future will finish once close is received by the server.
start_server_side(&mut self) -> Result<BatchFuture>455     pub fn start_server_side(&mut self) -> Result<BatchFuture> {
456         let _cq_ref = self.cq.borrow()?;
457         let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
458             grpc_sys::grpcwrap_call_start_serverside(self.call, ctx, tag)
459         });
460         Ok(f)
461     }
462 
463     /// Send a status from server.
start_send_status_from_server( &mut self, status: &RpcStatus, initial_metadata: &mut Option<Metadata>, call_flags: u32, send_empty_metadata: bool, payload: &mut Option<GrpcSlice>, write_flags: u32, ) -> Result<BatchFuture>464     pub fn start_send_status_from_server(
465         &mut self,
466         status: &RpcStatus,
467         initial_metadata: &mut Option<Metadata>,
468         call_flags: u32,
469         send_empty_metadata: bool,
470         payload: &mut Option<GrpcSlice>,
471         write_flags: u32,
472     ) -> Result<BatchFuture> {
473         let _cq_ref = self.cq.borrow()?;
474 
475         if initial_metadata.is_none() && send_empty_metadata {
476             initial_metadata.replace(MetadataBuilder::new().build());
477         }
478 
479         let f = check_run(BatchType::Finish, |ctx, tag| unsafe {
480             let (msg_ptr, msg_len) = if status.code() == RpcStatusCode::OK {
481                 (ptr::null(), 0)
482             } else {
483                 (status.message.as_ptr(), status.message.len())
484             };
485             let payload_p = match payload {
486                 Some(p) => p.as_mut_ptr(),
487                 None => ptr::null_mut(),
488             };
489             let mut trailing_metadata: Option<Metadata> = if status.details.is_empty() {
490                 None
491             } else {
492                 let mut builder = MetadataBuilder::new();
493                 builder.set_binary_error_details(&status.details);
494                 Some(builder.build())
495             };
496             grpc_sys::grpcwrap_call_send_status_from_server(
497                 self.call,
498                 ctx,
499                 status.code().into(),
500                 msg_ptr as _,
501                 msg_len,
502                 initial_metadata
503                     .as_mut()
504                     .map_or_else(ptr::null_mut, |m| m as *mut _ as _),
505                 call_flags,
506                 trailing_metadata
507                     .as_mut()
508                     .map_or_else(ptr::null_mut, |m| m as *mut _ as _),
509                 payload_p,
510                 write_flags,
511                 tag,
512             )
513         });
514         Ok(f)
515     }
516 
517     /// Abort an rpc call before handler is called.
abort(self, status: &RpcStatus)518     pub fn abort(self, status: &RpcStatus) {
519         match self.cq.borrow() {
520             // Queue is shutdown, ignore.
521             Err(Error::QueueShutdown) => return,
522             Err(e) => panic!("unexpected error when aborting call: {:?}", e),
523             _ => {}
524         }
525         let call_ptr = self.call;
526         let tag = CallTag::abort(self);
527         let (batch_ptr, tag_ptr) = box_batch_tag(tag);
528 
529         let code = unsafe {
530             let (msg_ptr, msg_len) = if status.code() == RpcStatusCode::OK {
531                 (ptr::null(), 0)
532             } else {
533                 (status.message.as_ptr(), status.message.len())
534             };
535             grpc_sys::grpcwrap_call_send_status_from_server(
536                 call_ptr,
537                 batch_ptr,
538                 status.code().into(),
539                 msg_ptr as _,
540                 msg_len,
541                 (&mut MetadataBuilder::new().build()) as *mut _ as _,
542                 0,
543                 ptr::null_mut(),
544                 ptr::null_mut(),
545                 0,
546                 tag_ptr as *mut c_void,
547             )
548         };
549         if code != grpc_call_error::GRPC_CALL_OK {
550             unsafe {
551                 drop(Box::from_raw(tag_ptr));
552             }
553             panic!("create call fail: {:?}", code);
554         }
555     }
556 
557     /// Cancel the rpc call by client.
cancel(&self)558     fn cancel(&self) {
559         match self.cq.borrow() {
560             // Queue is shutdown, ignore.
561             Err(Error::QueueShutdown) => return,
562             Err(e) => panic!("unexpected error when canceling call: {:?}", e),
563             _ => {}
564         }
565         unsafe {
566             grpc_sys::grpc_call_cancel(self.call, ptr::null_mut());
567         }
568     }
569 }
570 
571 impl Drop for Call {
drop(&mut self)572     fn drop(&mut self) {
573         unsafe { grpc_sys::grpc_call_unref(self.call) }
574     }
575 }
576 
577 /// A share object for client streaming and duplex streaming call.
578 ///
579 /// In both cases, receiver and sender can be polled in the same time,
580 /// hence we need to share the call in the both sides and abort the sink
581 /// once the call is canceled or finished early.
582 struct ShareCall {
583     call: Call,
584     close_f: BatchFuture,
585     finished: bool,
586     status: Option<RpcStatus>,
587 }
588 
589 impl ShareCall {
new(call: Call, close_f: BatchFuture) -> ShareCall590     fn new(call: Call, close_f: BatchFuture) -> ShareCall {
591         ShareCall {
592             call,
593             close_f,
594             finished: false,
595             status: None,
596         }
597     }
598 
599     /// Poll if the call is still alive.
600     ///
601     /// If the call is still running, will register a notification for its completion.
poll_finish(&mut self, cx: &mut Context) -> Poll<Result<BatchResult>>602     fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<BatchResult>> {
603         let res = match Pin::new(&mut self.close_f).poll(cx) {
604             Poll::Ready(Ok(reader)) => {
605                 self.status = Some(RpcStatus::ok());
606                 Poll::Ready(Ok(reader))
607             }
608             Poll::Pending => return Poll::Pending,
609             Poll::Ready(Err(Error::RpcFailure(status))) => {
610                 self.status = Some(status.clone());
611                 Poll::Ready(Err(Error::RpcFailure(status)))
612             }
613             res => res,
614         };
615 
616         self.finished = true;
617         res
618     }
619 
620     /// Check if the call is finished.
check_alive(&mut self) -> Result<()>621     fn check_alive(&mut self) -> Result<()> {
622         if self.finished {
623             // maybe can just take here.
624             return Err(Error::RpcFinished(self.status.clone()));
625         }
626 
627         task::check_alive(&self.close_f)
628     }
629 }
630 
631 /// A helper trait that allows executing function on the internal `ShareCall` struct.
632 trait ShareCallHolder {
call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R633     fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R;
634 }
635 
636 impl ShareCallHolder for ShareCall {
call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R637     fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R {
638         f(self)
639     }
640 }
641 
642 impl ShareCallHolder for Arc<Mutex<ShareCall>> {
call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R643     fn call<R, F: FnOnce(&mut ShareCall) -> R>(&mut self, f: F) -> R {
644         let mut call = self.lock();
645         f(&mut call)
646     }
647 }
648 
649 /// A helper struct for constructing Stream object for batch requests.
650 struct StreamingBase {
651     close_f: Option<BatchFuture>,
652     msg_f: Option<BatchFuture>,
653     read_done: bool,
654 }
655 
impl StreamingBase {
    /// Create a stream helper with no read in flight.
    ///
    /// `close_f` is the call-completion future to watch; pass `None` when there is none to
    /// track.
    fn new(close_f: Option<BatchFuture>) -> StreamingBase {
        StreamingBase {
            close_f,
            msg_f: None,
            read_done: false,
        }
    }

    /// Poll for the next received message.
    ///
    /// Unless `skip_finish_check` is set, the close future is polled first. An error from it
    /// is returned immediately; completion only drops `close_f`, so pending data can still
    /// be read. Then the in-flight read (if any) is polled.
    /// - A message yields `Ready(Some(Ok(reader)))`, and a new read is started right away.
    /// - An empty read marks `read_done`. After that, the stream returns `Ready(None)` once
    ///   `close_f` is gone, and `Pending` otherwise.
    /// - When no read was in flight, one is started and this function recurses once with
    ///   `skip_finish_check = true` to poll the fresh read.
    fn poll<C: ShareCallHolder>(
        &mut self,
        cx: &mut Context,
        call: &mut C,
        skip_finish_check: bool,
    ) -> Poll<Option<Result<MessageReader>>> {
        if !skip_finish_check {
            let mut finished = false;
            if let Some(close_f) = &mut self.close_f {
                if Pin::new(close_f).poll(cx)?.is_ready() {
                    // Don't return immediately, there may be pending data.
                    finished = true;
                }
            }
            if finished {
                self.close_f.take();
            }
        }

        let mut bytes = None;
        if !self.read_done {
            if let Some(msg_f) = &mut self.msg_f {
                // Returns `Pending` from this function if the read has not completed yet.
                bytes = ready!(Pin::new(msg_f).poll(cx)?).message_reader;
                if bytes.is_none() {
                    // A read that completes without a message ends the stream.
                    self.read_done = true;
                }
            }
        }

        if self.read_done {
            // The stream is exhausted; it only ends once the close future is gone as well.
            if self.close_f.is_none() {
                return Poll::Ready(None);
            }
            return Poll::Pending;
        }

        // Reads are not done, so `msg_f` is either a completed (stale) read or was never
        // started; replace it with a fresh receive.
        self.msg_f.take();
        let msg_f = call.call(|c| c.call.start_recv_message())?;
        self.msg_f = Some(msg_f);
        if bytes.is_none() {
            // No read was in flight before: poll the new one so a wake-up gets registered.
            self.poll(cx, call, true)
        } else {
            Poll::Ready(bytes.map(Ok))
        }
    }

    // Cancel the call if we still have some messages or did not
    // receive status code.
    /// Cancel the call on drop if messages may still be pending (`!read_done`) or the close
    /// future has not completed yet (`close_f` is still set).
    fn on_drop<C: ShareCallHolder>(&self, call: &mut C) {
        if !self.read_done || self.close_f.is_some() {
            call.call(|c| c.call.cancel());
        }
    }
}
720 
721 /// Flags for write operations.
722 #[derive(Default, Clone, Copy)]
723 pub struct WriteFlags {
724     flags: u32,
725 }
726 
727 impl WriteFlags {
728     /// Hint that the write may be buffered and need not go out on the wire immediately.
729     ///
730     /// gRPC is free to buffer the message until the next non-buffered write, or until write stream
731     /// completion, but it need not buffer completely or at all.
buffer_hint(mut self, need_buffered: bool) -> WriteFlags732     pub fn buffer_hint(mut self, need_buffered: bool) -> WriteFlags {
733         client::change_flag(
734             &mut self.flags,
735             grpc_sys::GRPC_WRITE_BUFFER_HINT,
736             need_buffered,
737         );
738         self
739     }
740 
741     /// Force compression to be disabled.
force_no_compress(mut self, no_compress: bool) -> WriteFlags742     pub fn force_no_compress(mut self, no_compress: bool) -> WriteFlags {
743         client::change_flag(
744             &mut self.flags,
745             grpc_sys::GRPC_WRITE_NO_COMPRESS,
746             no_compress,
747         );
748         self
749     }
750 
751     /// Get whether buffer hint is enabled.
get_buffer_hint(self) -> bool752     pub fn get_buffer_hint(self) -> bool {
753         (self.flags & grpc_sys::GRPC_WRITE_BUFFER_HINT) != 0
754     }
755 
756     /// Get whether compression is disabled.
get_force_no_compress(self) -> bool757     pub fn get_force_no_compress(self) -> bool {
758         (self.flags & grpc_sys::GRPC_WRITE_NO_COMPRESS) != 0
759     }
760 }
761 
762 /// A helper struct for constructing Sink object for batch requests.
763 struct SinkBase {
764     // Batch job to be executed in `poll_ready`.
765     batch_f: Option<BatchFuture>,
766     headers: Metadata,
767     send_metadata: bool,
768     // Flag to indicate if enhance batch strategy. This behavior will modify the `buffer_hint` to batch
769     // messages as much as possible.
770     enhance_buffer_strategy: bool,
771     // Buffer used to store the data to be sent, send out the last data in this round of `start_send`.
772     buffer: GrpcSlice,
773     // Write flags used to control the data to be sent in `buffer`.
774     buf_flags: Option<WriteFlags>,
775     // Used to records whether a message in which `buffer_hint` is false exists.
776     // Note: only used in enhanced buffer strategy.
777     last_buf_hint: bool,
778 }
779 
impl SinkBase {
    /// Create a sink helper with an empty buffer and empty header metadata.
    ///
    /// `send_metadata` controls whether `headers` are sent along with the first message.
    fn new(send_metadata: bool) -> SinkBase {
        SinkBase {
            batch_f: None,
            headers: MetadataBuilder::new().build(),
            send_metadata,
            enhance_buffer_strategy: false,
            buffer: GrpcSlice::default(),
            buf_flags: None,
            last_buf_hint: true,
        }
    }

    /// Serialize `t` with `ser` and queue it for sending.
    ///
    /// The first message (while `send_metadata` is set) is sent immediately together with
    /// the headers, with `buffer_hint` forced to `false`. After that:
    /// - any previously buffered message is sent first with `buffer_hint = true`, since
    ///   another message follows it;
    /// - the new message is buffered;
    /// - without the enhanced strategy, the new message is sent straight away using the
    ///   caller's buffer hint.
    ///
    /// Must only be called after `poll_ready` reported ready; see the assertion in
    /// `start_send_buffer_message`.
    fn start_send<T, C: ShareCallHolder>(
        &mut self,
        call: &mut C,
        t: &T,
        flags: WriteFlags,
        ser: SerializeFn<T>,
        call_flags: u32,
    ) -> Result<()> {
        // temporary fix: buffer hint with send meta will not send out any metadata.
        // note: only the first message can enter this code block.
        if self.send_metadata {
            ser(t, &mut self.buffer)?;
            self.buf_flags = Some(flags);
            self.start_send_buffer_message(false, call, call_flags)?;
            self.send_metadata = false;
            return Ok(());
        }

        // If there is already a buffered message waiting to be sent, set `buffer_hint` to true to indicate
        // that this is not the last message.
        if self.buf_flags.is_some() {
            self.start_send_buffer_message(true, call, call_flags)?;
        }

        ser(t, &mut self.buffer)?;
        let hint = flags.get_buffer_hint();
        // Remember whether any message in this round asked not to be buffered.
        self.last_buf_hint &= hint;
        self.buf_flags = Some(flags);

        // If sink disable batch, start sending the message in buffer immediately.
        if !self.enhance_buffer_strategy {
            self.start_send_buffer_message(hint, call, call_flags)?;
        }

        Ok(())
    }

    /// Wait until no send batch is in flight.
    ///
    /// Returns `Ready(Ok(()))` immediately when idle. Otherwise it polls the in-flight batch,
    /// propagates its error, and clears it once it has completed.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<()>> {
        match &mut self.batch_f {
            None => return Poll::Ready(Ok(())),
            Some(f) => {
                ready!(Pin::new(f).poll(cx)?);
            }
        }
        self.batch_f.take();
        Poll::Ready(Ok(()))
    }

    /// Flush everything queued so far.
    ///
    /// Waits for any in-flight batch, then sends the buffered message (if any) with the
    /// accumulated `last_buf_hint` and waits for that send too. `last_buf_hint` is reset to
    /// `true` only once the flush completes. If the final send is still pending, the next
    /// call finds no buffered message and just waits for the in-flight batch.
    #[inline]
    fn poll_flush<C: ShareCallHolder>(
        &mut self,
        cx: &mut Context,
        call: &mut C,
        call_flags: u32,
    ) -> Poll<Result<()>> {
        if self.batch_f.is_some() {
            ready!(self.poll_ready(cx)?);
        }
        if self.buf_flags.is_some() {
            self.start_send_buffer_message(self.last_buf_hint, call, call_flags)?;
            ready!(self.poll_ready(cx)?);
        }
        self.last_buf_hint = true;
        Poll::Ready(Ok(()))
    }

    /// Start sending the message held in `buffer`, with its stored flags and `buffer_hint`
    /// overridden by the argument.
    ///
    /// Headers are attached while `send_metadata` is set. Afterwards, a non-inline `buffer`
    /// is replaced by a fresh default slice (an inline one is kept for reuse), and
    /// `buf_flags` is cleared.
    ///
    /// # Panics
    ///
    /// Panics if a batch is already in flight, or if no message is buffered
    /// (`buf_flags` is `None`).
    #[inline]
    fn start_send_buffer_message<C: ShareCallHolder>(
        &mut self,
        buffer_hint: bool,
        call: &mut C,
        call_flags: u32,
    ) -> Result<()> {
        // `start_send` is supposed to be called after `poll_ready` returns ready.
        assert!(self.batch_f.is_none());

        let buffer = &mut self.buffer;
        let mut flags = self.buf_flags.unwrap();
        flags = flags.buffer_hint(buffer_hint);

        let headers = if self.send_metadata {
            Some(&mut self.headers)
        } else {
            None
        };

        let write_f = call.call(|c| {
            c.call
                .start_send_message(buffer, flags.flags, headers, call_flags)
        })?;
        self.batch_f = Some(write_f);
        // NOTE(review): presumably a non-inline slice's storage is still referenced by the
        // in-flight send, so a fresh slice is used for the next message; confirm in GrpcSlice.
        if !self.buffer.is_inline() {
            self.buffer = GrpcSlice::default();
        }
        self.buf_flags.take();
        Ok(())
    }
}
892