1 use axum::{
2     extract::State,
3     routing::{get, post},
4     Extension, Json, Router, Server,
5 };
6 use hyper::server::conn::AddrIncoming;
7 use serde::{Deserialize, Serialize};
8 use std::{
9     io::BufRead,
10     process::{Command, Stdio},
11 };
12 
main()13 fn main() {
14     if on_ci() {
15         install_rewrk();
16     } else {
17         ensure_rewrk_is_installed();
18     }
19 
20     benchmark("minimal").run(Router::new);
21 
22     benchmark("basic")
23         .path("/a/b/c")
24         .run(|| Router::new().route("/a/b/c", get(|| async { "Hello, World!" })));
25 
26     benchmark("basic-merge").path("/a/b/c").run(|| {
27         let inner = Router::new().route("/a/b/c", get(|| async { "Hello, World!" }));
28         Router::new().merge(inner)
29     });
30 
31     benchmark("basic-nest").path("/a/b/c").run(|| {
32         let c = Router::new().route("/c", get(|| async { "Hello, World!" }));
33         let b = Router::new().nest("/b", c);
34         Router::new().nest("/a", b)
35     });
36 
37     benchmark("routing").path("/foo/bar/baz").run(|| {
38         let mut app = Router::new();
39         for a in 0..10 {
40             for b in 0..10 {
41                 for c in 0..10 {
42                     app = app.route(&format!("/foo-{a}/bar-{b}/baz-{c}"), get(|| async {}));
43                 }
44             }
45         }
46         app.route("/foo/bar/baz", get(|| async {}))
47     });
48 
49     benchmark("receive-json")
50         .method("post")
51         .headers(&[("content-type", "application/json")])
52         .body(r#"{"n": 123, "s": "hi there", "b": false}"#)
53         .run(|| Router::new().route("/", post(|_: Json<Payload>| async {})));
54 
55     benchmark("send-json").run(|| {
56         Router::new().route(
57             "/",
58             get(|| async {
59                 Json(Payload {
60                     n: 123,
61                     s: "hi there".to_owned(),
62                     b: false,
63                 })
64             }),
65         )
66     });
67 
68     let state = AppState {
69         _string: "aaaaaaaaaaaaaaaaaa".to_owned(),
70         _vec: Vec::from([
71             "aaaaaaaaaaaaaaaaaa".to_owned(),
72             "bbbbbbbbbbbbbbbbbb".to_owned(),
73             "cccccccccccccccccc".to_owned(),
74         ]),
75     };
76 
77     benchmark("extension").run(|| {
78         Router::new()
79             .route("/", get(|_: Extension<AppState>| async {}))
80             .layer(Extension(state.clone()))
81     });
82 
83     benchmark("state").run(|| {
84         Router::new()
85             .route("/", get(|_: State<AppState>| async {}))
86             .with_state(state.clone())
87     });
88 }
89 
/// State handed to handlers by the "extension" and "state" benchmarks.
///
/// The fields are never read (hence the leading underscores). Presumably they exist so that
/// cloning the state during extraction has a realistic cost: one `String` plus a
/// `Vec<String>`. Confirm before slimming them down.
#[derive(Clone)]
struct AppState {
    /// A single heap-allocated string.
    _string: String,
    /// A small collection of heap-allocated strings.
    _vec: Vec<String>,
}
95 
96 #[derive(Deserialize, Serialize)]
97 struct Payload {
98     n: u32,
99     s: String,
100     b: bool,
101 }
102 
benchmark(name: &'static str) -> BenchmarkBuilder103 fn benchmark(name: &'static str) -> BenchmarkBuilder {
104     BenchmarkBuilder {
105         name,
106         path: None,
107         method: None,
108         headers: None,
109         body: None,
110     }
111 }
112 
/// Configuration for a single `rewrk` benchmark run.
///
/// Every optional field that is left as `None` is simply not passed to `rewrk`, except
/// `path`, which defaults to the empty string (the server root).
struct BenchmarkBuilder {
    /// Benchmark name, printed before the run and matched against CLI filter arguments.
    name: &'static str,
    /// Request path appended to the server address.
    path: Option<&'static str>,
    /// HTTP method forwarded as `--method`.
    method: Option<&'static str>,
    /// Header `(name, value)` pairs, each forwarded as `--header "name: value"`.
    headers: Option<&'static [(&'static str, &'static str)]>,
    /// Request body forwarded as `--body`.
    body: Option<&'static str>,
}
120 
/// Generates a by-value builder setter `fn $name(self, value: $ty) -> Self`.
///
/// The setter stores `Some(value)` in the field of the same name and leaves every other
/// field untouched.
macro_rules! config_method {
    ($name:ident, $ty:ty) => {
        fn $name(self, value: $ty) -> Self {
            Self {
                $name: Some(value),
                ..self
            }
        }
    };
}
129 
130 impl BenchmarkBuilder {
131     config_method!(path, &'static str);
132     config_method!(method, &'static str);
133     config_method!(headers, &'static [(&'static str, &'static str)]);
134     config_method!(body, &'static str);
135 
run<F>(self, f: F) where F: FnOnce() -> Router<()>,136     fn run<F>(self, f: F)
137     where
138         F: FnOnce() -> Router<()>,
139     {
140         // support only running some benchmarks with
141         // ```
142         // cargo bench -- routing send-json
143         // ```
144         let args = std::env::args().collect::<Vec<_>>();
145         if args.len() != 1 {
146             let names = &args[1..args.len() - 1];
147             if !names.is_empty() && !names.contains(&self.name.to_owned()) {
148                 return;
149             }
150         }
151 
152         let app = f();
153 
154         let rt = tokio::runtime::Builder::new_multi_thread()
155             .enable_all()
156             .build()
157             .unwrap();
158 
159         let listener = rt
160             .block_on(tokio::net::TcpListener::bind("0.0.0.0:0"))
161             .unwrap();
162         let addr = listener.local_addr().unwrap();
163 
164         std::thread::spawn(move || {
165             rt.block_on(async move {
166                 let incoming = AddrIncoming::from_listener(listener).unwrap();
167                 Server::builder(incoming)
168                     .serve(app.into_make_service())
169                     .await
170                     .unwrap();
171             });
172         });
173 
174         let mut cmd = Command::new("rewrk");
175         cmd.stdout(Stdio::piped());
176 
177         cmd.arg("--host");
178         cmd.arg(format!("http://{}{}", addr, self.path.unwrap_or("")));
179 
180         cmd.args(["--connections", "10"]);
181         cmd.args(["--threads", "10"]);
182 
183         if on_ci() {
184             // don't slow down CI by running the benchmarks for too long
185             // but do run them for a bit
186             cmd.args(["--duration", "1s"]);
187         } else {
188             cmd.args(["--duration", "10s"]);
189         }
190 
191         if let Some(method) = self.method {
192             cmd.args(["--method", method]);
193         }
194 
195         for (key, value) in self.headers.into_iter().flatten() {
196             cmd.arg("--header");
197             cmd.arg(format!("{key}: {value}"));
198         }
199 
200         if let Some(body) = self.body {
201             cmd.args(["--body", body]);
202         }
203 
204         eprintln!("Running {:?} benchmark", self.name);
205 
206         // indent output from `rewrk` so its easier to read when running multiple benchmarks
207         let mut child = cmd.spawn().unwrap();
208         let stdout = child.stdout.take().unwrap();
209         let stdout = std::io::BufReader::new(stdout);
210         for line in stdout.lines() {
211             let line = line.unwrap();
212             println!("  {line}");
213         }
214 
215         let status = child.wait().unwrap();
216 
217         if !status.success() {
218             eprintln!("`rewrk` command failed");
219             std::process::exit(status.code().unwrap());
220         }
221     }
222 }
223 
/// Installs `rewrk` from its git repository with `cargo install`.
///
/// # Panics
/// If `cargo` cannot be launched, or if the install exits unsuccessfully.
fn install_rewrk() {
    println!("installing rewrk");
    let succeeded = Command::new("cargo")
        .args([
            "install",
            "rewrk",
            "--git",
            "https://github.com/ChillFish8/rewrk.git",
        ])
        .status()
        .map(|status| status.success())
        .unwrap_or_else(|_| panic!("failed to install rewrk"));
    if !succeeded {
        panic!("failed to install rewrk");
    }
}
240 
/// Checks that a `rewrk` executable can be launched by running `rewrk --help` with its
/// output discarded.
///
/// Only the ability to spawn the process is checked; its exit status is ignored.
///
/// # Panics
/// If `rewrk` cannot be spawned.
fn ensure_rewrk_is_installed() {
    let probe = Command::new("rewrk")
        .arg("--help")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    if probe.is_err() {
        panic!("rewrk is not installed. See https://github.com/lnx-search/rewrk");
    }
}
250 
/// Reports whether we appear to be running under GitHub Actions.
///
/// Returns true when the `GITHUB_ACTIONS` environment variable is set to a valid
/// Unicode value.
fn on_ci() -> bool {
    matches!(std::env::var("GITHUB_ACTIONS"), Ok(_))
}
254