1 use crate::report::BenchmarkId as InternalBenchmarkId;
2 use crate::Throughput;
3 use std::cell::RefCell;
4 use std::convert::TryFrom;
5 use std::io::{Read, Write};
6 use std::mem::size_of;
7 use std::net::TcpStream;
8 
/// Errors that can occur while exchanging messages with cargo-criterion.
#[derive(Debug)]
pub enum MessageError {
    /// Decoding an incoming CBOR message failed.
    Deserialization(ciborium::de::Error<std::io::Error>),
    /// Encoding an outgoing message as CBOR failed.
    Serialization(ciborium::ser::Error<std::io::Error>),
    /// An I/O error, e.g. while reading from or writing to the socket.
    Io(std::io::Error),
}
15 impl From<ciborium::de::Error<std::io::Error>> for MessageError {
from(other: ciborium::de::Error<std::io::Error>) -> Self16     fn from(other: ciborium::de::Error<std::io::Error>) -> Self {
17         MessageError::Deserialization(other)
18     }
19 }
20 impl From<ciborium::ser::Error<std::io::Error>> for MessageError {
from(other: ciborium::ser::Error<std::io::Error>) -> Self21     fn from(other: ciborium::ser::Error<std::io::Error>) -> Self {
22         MessageError::Serialization(other)
23     }
24 }
25 impl From<std::io::Error> for MessageError {
from(other: std::io::Error) -> Self26     fn from(other: std::io::Error) -> Self {
27         MessageError::Io(other)
28     }
29 }
30 impl std::fmt::Display for MessageError {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result31     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32         match self {
33             MessageError::Deserialization(error) => write!(
34                 f,
35                 "Failed to deserialize message to Criterion.rs benchmark:\n{}",
36                 error
37             ),
38             MessageError::Serialization(error) => write!(
39                 f,
40                 "Failed to serialize message to Criterion.rs benchmark:\n{}",
41                 error
42             ),
43             MessageError::Io(error) => write!(
44                 f,
45                 "Failed to read or write message to Criterion.rs benchmark:\n{}",
46                 error
47             ),
48         }
49     }
50 }
51 impl std::error::Error for MessageError {
source(&self) -> Option<&(dyn std::error::Error + 'static)>52     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
53         match self {
54             MessageError::Deserialization(err) => Some(err),
55             MessageError::Serialization(err) => Some(err),
56             MessageError::Io(err) => Some(err),
57         }
58     }
59 }
60 
// Handshake sizes are derived from the magic strings (`str::len` is a `const fn` since
// Rust 1.39) so the lengths can never drift out of sync with the strings themselves.

/// Magic number at the start of the hello message sent by cargo-criterion (the runner).
const RUNNER_MAGIC_NUMBER: &str = "cargo-criterion";
/// Size in bytes of the runner-hello: the magic number followed by a 3-byte runner version.
const RUNNER_HELLO_SIZE: usize = RUNNER_MAGIC_NUMBER.len() // magic number
    + (size_of::<u8>() * 3); // version number

/// Magic number at the start of the hello message this benchmark sends back.
const BENCHMARK_MAGIC_NUMBER: &str = "Criterion";
/// Size in bytes of the benchmark-hello: magic number, 3-byte version, then the protocol
/// version and protocol format as `u16`s.
const BENCHMARK_HELLO_SIZE: usize = BENCHMARK_MAGIC_NUMBER.len() // magic number
    + (size_of::<u8>() * 3) // version number
    + size_of::<u16>() // protocol version
    + size_of::<u16>(); // protocol format
/// Protocol version announced in the benchmark-hello.
const PROTOCOL_VERSION: u16 = 1;
/// Protocol format announced in the benchmark-hello.
// NOTE(review): presumably identifies the CBOR message encoding — confirm with cargo-criterion.
const PROTOCOL_FORMAT: u16 = 1;
73 
/// The socket to cargo-criterion plus scratch buffers reused across messages.
#[derive(Debug)]
struct InnerConnection {
    /// TCP stream to cargo-criterion, handed in by the caller of `new`.
    socket: TcpStream,
    /// Reused by `recv` to hold each incoming message's payload bytes.
    receive_buffer: Vec<u8>,
    /// Reused by `send` to hold each outgoing message's encoded bytes.
    send_buffer: Vec<u8>,
    // NOTE: the runner version read during the handshake is only logged, not stored.
    // runner_version: [u8; 3],
}
81 impl InnerConnection {
new(mut socket: TcpStream) -> Result<Self, std::io::Error>82     pub fn new(mut socket: TcpStream) -> Result<Self, std::io::Error> {
83         // read the runner-hello
84         let mut hello_buf = [0u8; RUNNER_HELLO_SIZE];
85         socket.read_exact(&mut hello_buf)?;
86         assert_eq!(
87             &hello_buf[0..RUNNER_MAGIC_NUMBER.len()],
88             RUNNER_MAGIC_NUMBER.as_bytes(),
89             "Not connected to cargo-criterion."
90         );
91 
92         let i = RUNNER_MAGIC_NUMBER.len();
93         let runner_version = [hello_buf[i], hello_buf[i + 1], hello_buf[i + 2]];
94 
95         info!("Runner version: {:?}", runner_version);
96 
97         // now send the benchmark-hello
98         let mut hello_buf = [0u8; BENCHMARK_HELLO_SIZE];
99         hello_buf[0..BENCHMARK_MAGIC_NUMBER.len()]
100             .copy_from_slice(BENCHMARK_MAGIC_NUMBER.as_bytes());
101         let mut i = BENCHMARK_MAGIC_NUMBER.len();
102         hello_buf[i] = 0;
103         hello_buf[i + 1] = 0;
104         hello_buf[i + 2] = 0;
105         i += 3;
106         hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_VERSION.to_be_bytes());
107         i += 2;
108         hello_buf[i..i + 2].clone_from_slice(&PROTOCOL_FORMAT.to_be_bytes());
109 
110         socket.write_all(&hello_buf)?;
111 
112         Ok(InnerConnection {
113             socket,
114             receive_buffer: vec![],
115             send_buffer: vec![],
116             // runner_version,
117         })
118     }
119 
120     #[allow(dead_code)]
recv(&mut self) -> Result<IncomingMessage, MessageError>121     pub fn recv(&mut self) -> Result<IncomingMessage, MessageError> {
122         let mut length_buf = [0u8; 4];
123         self.socket.read_exact(&mut length_buf)?;
124         let length = u32::from_be_bytes(length_buf);
125         self.receive_buffer.resize(length as usize, 0u8);
126         self.socket.read_exact(&mut self.receive_buffer)?;
127         let value = ciborium::de::from_reader(&self.receive_buffer[..])?;
128         Ok(value)
129     }
130 
send(&mut self, message: &OutgoingMessage) -> Result<(), MessageError>131     pub fn send(&mut self, message: &OutgoingMessage) -> Result<(), MessageError> {
132         self.send_buffer.truncate(0);
133         ciborium::ser::into_writer(message, &mut self.send_buffer)?;
134         let size = u32::try_from(self.send_buffer.len()).unwrap();
135         let length_buf = size.to_be_bytes();
136         self.socket.write_all(&length_buf)?;
137         self.socket.write_all(&self.send_buffer)?;
138         Ok(())
139     }
140 }
141 
/// This is really just a holder to allow us to send messages through a shared reference to the
/// connection.
///
/// The inner connection lives in a `RefCell`, so `send` and `recv` can take `&self`. The borrow
/// is checked at runtime, and `RefCell` makes this type `!Sync`.
#[derive(Debug)]
pub struct Connection {
    inner: RefCell<InnerConnection>,
}
148 impl Connection {
new(socket: TcpStream) -> Result<Self, std::io::Error>149     pub fn new(socket: TcpStream) -> Result<Self, std::io::Error> {
150         Ok(Connection {
151             inner: RefCell::new(InnerConnection::new(socket)?),
152         })
153     }
154 
155     #[allow(dead_code)]
recv(&self) -> Result<IncomingMessage, MessageError>156     pub fn recv(&self) -> Result<IncomingMessage, MessageError> {
157         self.inner.borrow_mut().recv()
158     }
159 
send(&self, message: &OutgoingMessage) -> Result<(), MessageError>160     pub fn send(&self, message: &OutgoingMessage) -> Result<(), MessageError> {
161         self.inner.borrow_mut().send(message)
162     }
163 
serve_value_formatter( &self, formatter: &dyn crate::measurement::ValueFormatter, ) -> Result<(), MessageError>164     pub fn serve_value_formatter(
165         &self,
166         formatter: &dyn crate::measurement::ValueFormatter,
167     ) -> Result<(), MessageError> {
168         loop {
169             let response = match self.recv()? {
170                 IncomingMessage::FormatValue { value } => OutgoingMessage::FormattedValue {
171                     value: formatter.format_value(value),
172                 },
173                 IncomingMessage::FormatThroughput { value, throughput } => {
174                     OutgoingMessage::FormattedValue {
175                         value: formatter.format_throughput(&throughput, value),
176                     }
177                 }
178                 IncomingMessage::ScaleValues {
179                     typical_value,
180                     mut values,
181                 } => {
182                     let unit = formatter.scale_values(typical_value, &mut values);
183                     OutgoingMessage::ScaledValues {
184                         unit,
185                         scaled_values: values,
186                     }
187                 }
188                 IncomingMessage::ScaleThroughputs {
189                     typical_value,
190                     throughput,
191                     mut values,
192                 } => {
193                     let unit = formatter.scale_throughputs(typical_value, &throughput, &mut values);
194                     OutgoingMessage::ScaledValues {
195                         unit,
196                         scaled_values: values,
197                     }
198                 }
199                 IncomingMessage::ScaleForMachines { mut values } => {
200                     let unit = formatter.scale_for_machines(&mut values);
201                     OutgoingMessage::ScaledValues {
202                         unit,
203                         scaled_values: values,
204                     }
205                 }
206                 IncomingMessage::Continue => break,
207                 _ => panic!(),
208             };
209             self.send(&response)?;
210         }
211         Ok(())
212     }
213 }
214 
/// Enum defining the messages we can receive from cargo-criterion.
///
/// Decoded from CBOR by `InnerConnection::recv`. Every variant other than `Continue` and
/// `__Other` is a request answered by `Connection::serve_value_formatter`. Variant and field
/// names are part of the wire format.
#[derive(Debug, Deserialize)]
pub enum IncomingMessage {
    // Value formatter requests
    /// Format one value with the formatter's `format_value`.
    FormatValue {
        value: f64,
    },
    /// Format one value with the formatter's `format_throughput`.
    FormatThroughput {
        value: f64,
        throughput: Throughput,
    },
    /// Scale `values` in place with the formatter's `scale_values`, given `typical_value`.
    ScaleValues {
        typical_value: f64,
        values: Vec<f64>,
    },
    /// Scale `values` in place with the formatter's `scale_throughputs`.
    ScaleThroughputs {
        typical_value: f64,
        values: Vec<f64>,
        throughput: Throughput,
    },
    /// Scale `values` in place with the formatter's `scale_for_machines`.
    ScaleForMachines {
        values: Vec<f64>,
    },
    /// Ends the value-formatter serving loop.
    Continue,

    // NOTE(review): placeholder variant, presumably reserved for forward compatibility —
    // confirm. Without `#[serde(other)]` it only matches a message tagged `__Other`; other
    // unknown tags fail to deserialize.
    __Other,
}
242 
/// Enum defining the messages we can send to cargo-criterion.
///
/// Encoded as CBOR by `InnerConnection::send`. The borrowed fields (`&'a str`, `&'a [f64]`)
/// let callers send data without building owned copies. Variant names, field names and field
/// order are part of the serialized form.
#[derive(Debug, Serialize)]
pub enum OutgoingMessage<'a> {
    // Benchmark progress notifications (constructed by callers outside this file).
    BeginningBenchmarkGroup {
        group: &'a str,
    },
    FinishedBenchmarkGroup {
        group: &'a str,
    },
    BeginningBenchmark {
        id: RawBenchmarkId,
    },
    SkippingBenchmark {
        id: RawBenchmarkId,
    },
    Warmup {
        id: RawBenchmarkId,
        nanos: f64,
    },
    MeasurementStart {
        id: RawBenchmarkId,
        sample_count: u64,
        estimate_ns: f64,
        iter_count: u64,
    },
    /// Raw samples plus serializable snapshots of the plot, sampling and benchmark settings.
    MeasurementComplete {
        id: RawBenchmarkId,
        iters: &'a [f64],
        times: &'a [f64],
        plot_config: PlotConfiguration,
        sampling_method: SamplingMethod,
        benchmark_config: BenchmarkConfig,
    },
    // value formatter responses
    /// Reply to `FormatValue` / `FormatThroughput`.
    FormattedValue {
        value: String,
    },
    /// Reply to the `Scale*` requests: the scaled values and the unit chosen by the formatter.
    ScaledValues {
        scaled_values: Vec<f64>,
        unit: &'a str,
    },
}
285 
286 // Also define serializable variants of certain things, either to avoid leaking
287 // serializability into the public interface or because the serialized form
288 // is a bit different from the regular one.
289 
/// Serializable form of `crate::report::BenchmarkId` (imported here as `InternalBenchmarkId`).
#[derive(Debug, Serialize)]
pub struct RawBenchmarkId {
    group_id: String,
    function_id: Option<String>,
    value_str: Option<String>,
    /// Collected element-wise from the internal id's `throughput` by the `From` impl.
    throughput: Vec<Throughput>,
}
297 impl From<&InternalBenchmarkId> for RawBenchmarkId {
from(other: &InternalBenchmarkId) -> RawBenchmarkId298     fn from(other: &InternalBenchmarkId) -> RawBenchmarkId {
299         RawBenchmarkId {
300             group_id: other.group_id.clone(),
301             function_id: other.function_id.clone(),
302             value_str: other.value_str.clone(),
303             throughput: other.throughput.iter().cloned().collect(),
304         }
305     }
306 }
307 
/// Serializable mirror of `crate::AxisScale`.
#[derive(Debug, Serialize)]
pub enum AxisScale {
    Linear,
    Logarithmic,
}
313 impl From<crate::AxisScale> for AxisScale {
from(other: crate::AxisScale) -> Self314     fn from(other: crate::AxisScale) -> Self {
315         match other {
316             crate::AxisScale::Linear => AxisScale::Linear,
317             crate::AxisScale::Logarithmic => AxisScale::Logarithmic,
318         }
319     }
320 }
321 
/// Serializable subset of `crate::PlotConfiguration`: only the summary scale is sent.
#[derive(Debug, Serialize)]
pub struct PlotConfiguration {
    summary_scale: AxisScale,
}
326 impl From<&crate::PlotConfiguration> for PlotConfiguration {
from(other: &crate::PlotConfiguration) -> Self327     fn from(other: &crate::PlotConfiguration) -> Self {
328         PlotConfiguration {
329             summary_scale: other.summary_scale.into(),
330         }
331     }
332 }
333 
/// Serializable stand-in for `std::time::Duration`: whole seconds plus the sub-second
/// remainder in nanoseconds.
#[derive(Debug, Serialize)]
struct Duration {
    secs: u64,
    nanos: u32,
}
339 impl From<std::time::Duration> for Duration {
from(other: std::time::Duration) -> Self340     fn from(other: std::time::Duration) -> Self {
341         Duration {
342             secs: other.as_secs(),
343             nanos: other.subsec_nanos(),
344         }
345     }
346 }
347 
/// Serializable subset of `crate::benchmark::BenchmarkConfig`, sent with `MeasurementComplete`.
///
/// Field values are copied from the internal config; the two time spans are converted to the
/// local secs/nanos `Duration`. Field order is the serialized order.
#[derive(Debug, Serialize)]
pub struct BenchmarkConfig {
    confidence_level: f64,
    measurement_time: Duration,
    noise_threshold: f64,
    nresamples: usize,
    sample_size: usize,
    significance_level: f64,
    warm_up_time: Duration,
}
358 impl From<&crate::benchmark::BenchmarkConfig> for BenchmarkConfig {
from(other: &crate::benchmark::BenchmarkConfig) -> Self359     fn from(other: &crate::benchmark::BenchmarkConfig) -> Self {
360         BenchmarkConfig {
361             confidence_level: other.confidence_level,
362             measurement_time: other.measurement_time.into(),
363             noise_threshold: other.noise_threshold,
364             nresamples: other.nresamples,
365             sample_size: other.sample_size,
366             significance_level: other.significance_level,
367             warm_up_time: other.warm_up_time.into(),
368         }
369     }
370 }
371 
/// Sampling method reported to cargo-criterion in `OutgoingMessage::MeasurementComplete`.
///
/// Serializable mirror of `crate::ActualSamplingMode` (see the `From` impl).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum SamplingMethod {
    Linear,
    Flat,
}
378 impl From<crate::ActualSamplingMode> for SamplingMethod {
from(other: crate::ActualSamplingMode) -> Self379     fn from(other: crate::ActualSamplingMode) -> Self {
380         match other {
381             crate::ActualSamplingMode::Flat => SamplingMethod::Flat,
382             crate::ActualSamplingMode::Linear => SamplingMethod::Linear,
383         }
384     }
385 }
386