1 use bencher::{benchmark_group, benchmark_main, Bencher};
2 use bytes::{Buf, BufMut, Bytes, BytesMut};
3 use http_body::Body;
4 use std::{
5     fmt::{Error, Formatter},
6     pin::Pin,
7     task::{Context, Poll},
8 };
9 use tonic::{codec::DecodeBuf, codec::Decoder, Status, Streaming};
10 
/// Defines a benchmark function `$name` that decodes `$message_count` frames of
/// `$message_size` bytes each from a `MockBody` delivered in `$chunk_size`-byte
/// chunks, asserting the length of every message, the total message count, and
/// that the stream carries no trailers.
macro_rules! bench {
    ($name:ident, $message_size:expr, $chunk_size:expr, $message_count:expr) => {
        fn $name(b: &mut Bencher) {
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .build()
                .expect("runtime");

            let body = MockBody::new(make_payload($message_size, $message_count), $chunk_size);
            // Throughput is reported over the whole encoded payload, frame headers included.
            b.bytes = body.len() as u64;

            b.iter(|| {
                runtime.block_on(async {
                    let mut stream = Streaming::new_request(
                        MockDecoder::new($message_size),
                        body.clone(),
                        None,
                        None,
                    );

                    let mut decoded = 0;
                    loop {
                        match stream.message().await.unwrap() {
                            Some(msg) => {
                                assert_eq!($message_size, msg.len());
                                decoded += 1;
                            }
                            None => break,
                        }
                    }

                    assert_eq!(decoded, $message_count);
                    assert!(stream.trailers().await.unwrap().is_none());
                })
            })
        }
    };
}
40 
41 #[derive(Clone)]
42 struct MockBody {
43     data: Bytes,
44     chunk_size: usize,
45 }
46 
47 impl MockBody {
new(data: Bytes, chunk_size: usize) -> Self48     pub fn new(data: Bytes, chunk_size: usize) -> Self {
49         MockBody { data, chunk_size }
50     }
51 
len(&self) -> usize52     pub fn len(&self) -> usize {
53         self.data.len()
54     }
55 }
56 
57 impl Body for MockBody {
58     type Data = Bytes;
59     type Error = Status;
60 
poll_data( mut self: Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>61     fn poll_data(
62         mut self: Pin<&mut Self>,
63         _: &mut Context<'_>,
64     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
65         if self.data.has_remaining() {
66             let split = std::cmp::min(self.chunk_size, self.data.remaining());
67             Poll::Ready(Some(Ok(self.data.split_to(split))))
68         } else {
69             Poll::Ready(None)
70         }
71     }
72 
poll_trailers( self: Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>>73     fn poll_trailers(
74         self: Pin<&mut Self>,
75         _: &mut Context<'_>,
76     ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
77         Poll::Ready(Ok(None))
78     }
79 }
80 
81 impl std::fmt::Debug for MockBody {
fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>82     fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
83         let sample = self.data.iter().take(10).collect::<Vec<_>>();
84         write!(f, "{:?}...({})", sample, self.data.len())
85     }
86 }
87 
/// Test decoder parameterized by the size, in bytes, of each message it decodes.
#[derive(Clone, Debug)]
struct MockDecoder {
    /// Number of bytes consumed per decoded message.
    message_size: usize,
}

impl MockDecoder {
    /// Creates a decoder for messages of `message_size` bytes.
    fn new(message_size: usize) -> Self {
        Self { message_size }
    }
}
98 
99 impl Decoder for MockDecoder {
100     type Item = Vec<u8>;
101     type Error = Status;
102 
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>103     fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
104         let out = Vec::from(buf.chunk());
105         buf.advance(self.message_size);
106         Ok(Some(out))
107     }
108 }
109 
make_payload(message_length: usize, message_count: usize) -> Bytes110 fn make_payload(message_length: usize, message_count: usize) -> Bytes {
111     let mut buf = BytesMut::new();
112 
113     for _ in 0..message_count {
114         let msg = vec![97u8; message_length];
115         buf.reserve(msg.len() + 5);
116         buf.put_u8(0);
117         buf.put_u32(msg.len() as u32);
118         buf.put(&msg[..]);
119     }
120 
121     buf.freeze()
122 }
123 
124 // change body chunk size only
125 bench!(chunk_size_100, 1_000, 100, 1);
126 bench!(chunk_size_500, 1_000, 500, 1);
127 bench!(chunk_size_1005, 1_000, 1_005, 1);
128 
129 // change message size only
130 bench!(message_size_1k, 1_000, 1_005, 2);
131 bench!(message_size_5k, 5_000, 1_005, 2);
132 bench!(message_size_10k, 10_000, 1_005, 2);
133 
134 // change message count only
135 bench!(message_count_1, 500, 505, 1);
136 bench!(message_count_10, 500, 505, 10);
137 bench!(message_count_20, 500, 505, 20);
138 
139 benchmark_group!(chunk_size, chunk_size_100, chunk_size_500, chunk_size_1005);
140 
141 benchmark_group!(
142     message_size,
143     message_size_1k,
144     message_size_5k,
145     message_size_10k
146 );
147 
148 benchmark_group!(
149     message_count,
150     message_count_1,
151     message_count_10,
152     message_count_20
153 );
154 
155 benchmark_main!(chunk_size, message_size, message_count);
156