1 use axum::{
2 extract::State,
3 routing::{get, post},
4 Extension, Json, Router, Server,
5 };
6 use hyper::server::conn::AddrIncoming;
7 use serde::{Deserialize, Serialize};
8 use std::{
9 io::BufRead,
10 process::{Command, Stdio},
11 };
12
main()13 fn main() {
14 if on_ci() {
15 install_rewrk();
16 } else {
17 ensure_rewrk_is_installed();
18 }
19
20 benchmark("minimal").run(Router::new);
21
22 benchmark("basic")
23 .path("/a/b/c")
24 .run(|| Router::new().route("/a/b/c", get(|| async { "Hello, World!" })));
25
26 benchmark("basic-merge").path("/a/b/c").run(|| {
27 let inner = Router::new().route("/a/b/c", get(|| async { "Hello, World!" }));
28 Router::new().merge(inner)
29 });
30
31 benchmark("basic-nest").path("/a/b/c").run(|| {
32 let c = Router::new().route("/c", get(|| async { "Hello, World!" }));
33 let b = Router::new().nest("/b", c);
34 Router::new().nest("/a", b)
35 });
36
37 benchmark("routing").path("/foo/bar/baz").run(|| {
38 let mut app = Router::new();
39 for a in 0..10 {
40 for b in 0..10 {
41 for c in 0..10 {
42 app = app.route(&format!("/foo-{a}/bar-{b}/baz-{c}"), get(|| async {}));
43 }
44 }
45 }
46 app.route("/foo/bar/baz", get(|| async {}))
47 });
48
49 benchmark("receive-json")
50 .method("post")
51 .headers(&[("content-type", "application/json")])
52 .body(r#"{"n": 123, "s": "hi there", "b": false}"#)
53 .run(|| Router::new().route("/", post(|_: Json<Payload>| async {})));
54
55 benchmark("send-json").run(|| {
56 Router::new().route(
57 "/",
58 get(|| async {
59 Json(Payload {
60 n: 123,
61 s: "hi there".to_owned(),
62 b: false,
63 })
64 }),
65 )
66 });
67
68 let state = AppState {
69 _string: "aaaaaaaaaaaaaaaaaa".to_owned(),
70 _vec: Vec::from([
71 "aaaaaaaaaaaaaaaaaa".to_owned(),
72 "bbbbbbbbbbbbbbbbbb".to_owned(),
73 "cccccccccccccccccc".to_owned(),
74 ]),
75 };
76
77 benchmark("extension").run(|| {
78 Router::new()
79 .route("/", get(|_: Extension<AppState>| async {}))
80 .layer(Extension(state.clone()))
81 });
82
83 benchmark("state").run(|| {
84 Router::new()
85 .route("/", get(|_: State<AppState>| async {}))
86 .with_state(state.clone())
87 });
88 }
89
/// Sample application state handed to handlers in the `extension` and `state` benchmarks.
///
/// The fields are never read (hence the leading underscores). They only give the state an
/// owned, heap-backed payload — presumably so cloning it has a realistic cost; confirm
/// against the extractors' behaviour.
#[derive(Clone)]
struct AppState {
    /// A single owned string.
    _string: String,
    /// A short list of owned strings.
    _vec: Vec<String>,
}
95
96 #[derive(Deserialize, Serialize)]
97 struct Payload {
98 n: u32,
99 s: String,
100 b: bool,
101 }
102
benchmark(name: &'static str) -> BenchmarkBuilder103 fn benchmark(name: &'static str) -> BenchmarkBuilder {
104 BenchmarkBuilder {
105 name,
106 path: None,
107 method: None,
108 headers: None,
109 body: None,
110 }
111 }
112
/// Configuration for a single `rewrk` run against a freshly built router.
///
/// `None` fields are simply not passed to `rewrk`. `path` defaults to the empty string
/// when the URL is built.
struct BenchmarkBuilder {
    /// Benchmark name, used for CLI filtering and in progress output.
    name: &'static str,
    /// Request path appended to the server address.
    path: Option<&'static str>,
    /// HTTP method passed to `rewrk --method`.
    method: Option<&'static str>,
    /// Header `(name, value)` pairs, each passed as `rewrk --header "name: value"`.
    headers: Option<&'static [(&'static str, &'static str)]>,
    /// Request body passed to `rewrk --body`.
    body: Option<&'static str>,
}
120
/// Generates a consuming builder setter `fn $name(self, $name: $ty) -> Self` that stores
/// `Some($name)` in the field of the same name and returns the updated builder.
macro_rules! config_method {
    ($name:ident, $ty:ty) => {
        #[doc = concat!("Sets `", stringify!($name), "`, replacing any previous value.")]
        fn $name(self, $name: $ty) -> Self {
            Self {
                $name: Some($name),
                ..self
            }
        }
    };
}
129
130 impl BenchmarkBuilder {
131 config_method!(path, &'static str);
132 config_method!(method, &'static str);
133 config_method!(headers, &'static [(&'static str, &'static str)]);
134 config_method!(body, &'static str);
135
run<F>(self, f: F) where F: FnOnce() -> Router<()>,136 fn run<F>(self, f: F)
137 where
138 F: FnOnce() -> Router<()>,
139 {
140 // support only running some benchmarks with
141 // ```
142 // cargo bench -- routing send-json
143 // ```
144 let args = std::env::args().collect::<Vec<_>>();
145 if args.len() != 1 {
146 let names = &args[1..args.len() - 1];
147 if !names.is_empty() && !names.contains(&self.name.to_owned()) {
148 return;
149 }
150 }
151
152 let app = f();
153
154 let rt = tokio::runtime::Builder::new_multi_thread()
155 .enable_all()
156 .build()
157 .unwrap();
158
159 let listener = rt
160 .block_on(tokio::net::TcpListener::bind("0.0.0.0:0"))
161 .unwrap();
162 let addr = listener.local_addr().unwrap();
163
164 std::thread::spawn(move || {
165 rt.block_on(async move {
166 let incoming = AddrIncoming::from_listener(listener).unwrap();
167 Server::builder(incoming)
168 .serve(app.into_make_service())
169 .await
170 .unwrap();
171 });
172 });
173
174 let mut cmd = Command::new("rewrk");
175 cmd.stdout(Stdio::piped());
176
177 cmd.arg("--host");
178 cmd.arg(format!("http://{}{}", addr, self.path.unwrap_or("")));
179
180 cmd.args(["--connections", "10"]);
181 cmd.args(["--threads", "10"]);
182
183 if on_ci() {
184 // don't slow down CI by running the benchmarks for too long
185 // but do run them for a bit
186 cmd.args(["--duration", "1s"]);
187 } else {
188 cmd.args(["--duration", "10s"]);
189 }
190
191 if let Some(method) = self.method {
192 cmd.args(["--method", method]);
193 }
194
195 for (key, value) in self.headers.into_iter().flatten() {
196 cmd.arg("--header");
197 cmd.arg(format!("{key}: {value}"));
198 }
199
200 if let Some(body) = self.body {
201 cmd.args(["--body", body]);
202 }
203
204 eprintln!("Running {:?} benchmark", self.name);
205
206 // indent output from `rewrk` so its easier to read when running multiple benchmarks
207 let mut child = cmd.spawn().unwrap();
208 let stdout = child.stdout.take().unwrap();
209 let stdout = std::io::BufReader::new(stdout);
210 for line in stdout.lines() {
211 let line = line.unwrap();
212 println!(" {line}");
213 }
214
215 let status = child.wait().unwrap();
216
217 if !status.success() {
218 eprintln!("`rewrk` command failed");
219 std::process::exit(status.code().unwrap());
220 }
221 }
222 }
223
/// Installs `rewrk` from its canonical git repository via `cargo install`.
///
/// The repository URL matches the one `ensure_rewrk_is_installed` points users to. The old
/// `ChillFish8/rewrk` location only resolves through GitHub's rename redirect.
///
/// # Panics
///
/// Panics if `cargo` cannot be launched (the I/O error is included in the message) or if
/// `cargo install` exits unsuccessfully (the exit status is included).
fn install_rewrk() {
    println!("installing rewrk");
    let mut cmd = Command::new("cargo");
    cmd.args([
        "install",
        "rewrk",
        "--git",
        "https://github.com/lnx-search/rewrk.git",
    ]);
    let status = cmd
        .status()
        .unwrap_or_else(|err| panic!("failed to install rewrk: {err}"));
    if !status.success() {
        panic!("failed to install rewrk: `cargo install` exited with {status}");
    }
}
240
/// Verifies that a `rewrk` binary can be launched, panicking with install instructions if not.
///
/// `rewrk --help` is run with its output discarded. Only a failure to spawn the process
/// triggers the panic; its exit status is not inspected.
fn ensure_rewrk_is_installed() {
    let probe = Command::new("rewrk")
        .arg("--help")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    if probe.is_err() {
        panic!("rewrk is not installed. See https://github.com/lnx-search/rewrk")
    }
}
250
/// Returns `true` when running under GitHub Actions.
///
/// Detection is based on the `GITHUB_ACTIONS` environment variable. It counts as set only
/// if it is present and valid Unicode (`std::env::var` succeeds).
fn on_ci() -> bool {
    const CI_MARKER: &str = "GITHUB_ACTIONS";
    std::env::var(CI_MARKER).is_ok()
}
254