1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "perfetto/ext/base/http/http_server.h"
18
19 #include <initializer_list>
20 #include <string>
21
22 #include "perfetto/ext/base/string_utils.h"
23 #include "perfetto/ext/base/unix_socket.h"
24 #include "src/base/test/test_task_runner.h"
25 #include "test/gtest_and_gmock.h"
26
27 namespace perfetto {
28 namespace base {
29 namespace {
30
31 using testing::_;
32 using testing::Invoke;
33 using testing::InvokeWithoutArgs;
34 using testing::NiceMock;
35
36 constexpr int kTestPort = 5127; // Chosen with a fair dice roll.
37
38 class MockHttpHandler : public HttpRequestHandler {
39 public:
40 MOCK_METHOD(void, OnHttpRequest, (const HttpRequest&), (override));
41 MOCK_METHOD(void,
42 OnHttpConnectionClosed,
43 (HttpServerConnection*),
44 (override));
45 MOCK_METHOD(void, OnWebsocketMessage, (const WebsocketMessage&), (override));
46 };
47
48 class HttpCli {
49 public:
HttpCli(TestTaskRunner * ttr)50 explicit HttpCli(TestTaskRunner* ttr) : task_runner_(ttr) {
51 sock = UnixSocketRaw::CreateMayFail(SockFamily::kInet, SockType::kStream);
52 sock.SetBlocking(true);
53 sock.Connect("127.0.0.1:" + std::to_string(kTestPort));
54 }
55
SendHttpReq(std::initializer_list<std::string> headers,const std::string & body="")56 void SendHttpReq(std::initializer_list<std::string> headers,
57 const std::string& body = "") {
58 for (auto& header : headers)
59 sock.SendStr(header + "\r\n");
60 if (!body.empty())
61 sock.SendStr("Content-Length: " + std::to_string(body.size()) + "\r\n");
62 sock.SendStr("\r\n");
63 sock.SendStr(body);
64 }
65
Recv(size_t min_bytes)66 std::string Recv(size_t min_bytes) {
67 static int n = 0;
68 auto checkpoint_name = "rx_" + std::to_string(n++);
69 auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
70 std::string rxbuf;
71 sock.SetBlocking(false);
72 task_runner_->AddFileDescriptorWatch(sock.watch_handle(), [&] {
73 char buf[1024]{};
74 auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
75 if (rsize < 0)
76 return;
77 rxbuf.append(buf, static_cast<size_t>(rsize));
78 if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
79 checkpoint();
80 });
81 task_runner_->RunUntilCheckpoint(checkpoint_name);
82 task_runner_->RemoveFileDescriptorWatch(sock.watch_handle());
83 return rxbuf;
84 }
85
RecvAndWaitConnClose()86 std::string RecvAndWaitConnClose() { return Recv(0); }
87
88 TestTaskRunner* task_runner_;
89 UnixSocketRaw sock;
90 };
91
92 class HttpServerTest : public ::testing::Test {
93 public:
HttpServerTest()94 HttpServerTest() : srv_(&task_runner_, &handler_) { srv_.Start(kTestPort); }
95
96 TestTaskRunner task_runner_;
97 MockHttpHandler handler_;
98 HttpServer srv_;
99 };
100
// Issues kIterations GET requests, each on its own connection, and checks
// request parsing plus the "Connection: close" response.
TEST_F(HttpServerTest, GET) {
  constexpr int kIterations = 3;
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .Times(kIterations)
      .WillRepeatedly(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/foo/bar");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("42",
                  req.GetHeader("X-header").value_or("N/A").ToStdString());
        EXPECT_EQ("foo",
                  req.GetHeader("X-header2").value_or("N/A").ToStdString());
        EXPECT_FALSE(req.is_websocket_handshake);
        req.conn->SendResponseAndClose("200 OK", {}, "<html>");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);

  // Loop on kIterations (not a literal) so it cannot drift from the mock
  // expectations above.
  for (int i = 0; i < kIterations; i++) {
    HttpCli cli(&task_runner_);
    cli.SendHttpReq(
        {
            "GET /foo/bar HTTP/1.1",        //
            "Origin: https://example.com",  //
            "X-header: 42",                 //
            "X-header2: foo",               //
        },
        "");
    EXPECT_EQ(cli.RecvAndWaitConnClose(),
              "HTTP/1.1 200 OK\r\n"
              "Content-Length: 6\r\n"
              "Connection: close\r\n"
              "\r\n<html>");
  }
}
135
TEST_F(HttpServerTest, GET_404) {
  // The handler replies with a bare status line; the expected response below
  // shows the server adding Content-Length and closing the connection.
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/404");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        req.conn->SendResponseAndClose("404 Not Found");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  HttpCli client(&task_runner_);
  client.SendHttpReq({"GET /404 HTTP/1.1"}, "");

  const std::string expected =
      "HTTP/1.1 404 Not Found\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(client.RecvAndWaitConnClose(), expected);
}
152
TEST_F(HttpServerTest, POST) {
  // The body deliberately contains CRLF / LF sequences to check that it is
  // delimited by Content-Length rather than by line breaks.
  const std::string post_body = "the\r\npost\nbody\r\n\r\n";

  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/rpc");
        EXPECT_EQ(req.method.ToStdString(), "POST");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("foo", req.GetHeader("X-1").value_or("N/A").ToStdString());
        EXPECT_EQ(req.body.ToStdString(), "the\r\npost\nbody\r\n\r\n");
        req.conn->SendResponseAndClose("200 OK");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  HttpCli client(&task_runner_);
  client.SendHttpReq(
      {"POST /rpc HTTP/1.1", "Origin: https://example.com", "X-1: foo"},
      post_body);

  const std::string expected =
      "HTTP/1.1 200 OK\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(client.RecvAndWaitConnClose(), expected);
}
176
// A request the handler never responds to should cause an HTTP 500.
// NOTE: the "Unhadled" spelling is kept so any existing test filters still
// match.
TEST_F(HttpServerTest, Unhadled_500) {
  // The handler is invoked but takes no action (default mock behavior).
  EXPECT_CALL(handler_, OnHttpRequest(_));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  HttpCli client(&task_runner_);
  client.SendHttpReq({"GET /unhandled HTTP/1.1"});

  const std::string expected =
      "HTTP/1.1 500 Internal Server Error\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(client.RecvAndWaitConnClose(), expected);
}
189
// Sends three POST requests over a single keep-alive connection; the server
// closes it after the last response.
TEST_F(HttpServerTest, POST_Keepalive) {
  HttpCli cli(&task_runner_);
  constexpr int kNumRequests = 3;
  int req_num = 0;
  // All requests share one connection, so it is closed exactly once.
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(1);
  // Use kNumRequests (not a literal) so the expectation tracks the loop below.
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .Times(kNumRequests)
      .WillRepeatedly(Invoke([&](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/" + std::to_string(req_num));
        EXPECT_EQ(req.method.ToStdString(), "POST");
        EXPECT_EQ(req.body.ToStdString(), "body" + std::to_string(req_num));
        req.conn->SendResponseHeaders("200 OK");
        // Close from the server side after the last response so that
        // RecvAndWaitConnClose() below returns.
        if (++req_num == kNumRequests)
          req.conn->Close();
      }));

  for (int i = 0; i < kNumRequests; i++) {
    auto i_str = std::to_string(i);
    cli.SendHttpReq({"POST /" + i_str + " HTTP/1.1", "Connection: keep-alive"},
                    "body" + i_str);
  }

  std::string expected_response;
  for (int i = 0; i < kNumRequests; i++) {
    expected_response +=
        "HTTP/1.1 200 OK\r\n"
        "Content-Length: 0\r\n"
        "Connection: keep-alive\r\n"
        "\r\n";
  }
  EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
}
223
TEST_F(HttpServerTest, Websocket) {
  // Each iteration opens a fresh connection, upgrades it to a websocket,
  // exchanges kNumPings messages and then shuts it down from the client side.
  constexpr int kNumConnections = 3;
  constexpr int kNumPings = 3;
  srv_.AddAllowedOrigin("http://foo.com");
  srv_.AddAllowedOrigin("http://websocket.com");

  for (int conn = 0; conn < kNumConnections; ++conn) {
    HttpCli client(&task_runner_);
    EXPECT_CALL(handler_, OnHttpRequest(_))
        .WillOnce(Invoke([](const HttpRequest& req) {
          EXPECT_EQ(req.uri.ToStdString(), "/websocket");
          EXPECT_EQ(req.method.ToStdString(), "GET");
          EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
          EXPECT_TRUE(req.is_websocket_handshake);
          req.conn->UpgradeToWebsocket(req);
        }));

    client.SendHttpReq({
        "GET /websocket HTTP/1.1",                      //
        "Origin: http://websocket.com",                 //
        "Connection: upgrade",                          //
        "Upgrade: websocket",                           //
        "Sec-WebSocket-Version: 13",                    //
        "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
    });

    // The key/accept pair is the sample handshake from RFC 6455.
    const std::string handshake_resp =
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
        "Access-Control-Allow-Origin: http://websocket.com\r\n"
        "Vary: Origin\r\n"
        "\r\n";
    EXPECT_EQ(client.Recv(handshake_resp.size()), handshake_resp);

    for (int ping = 0; ping < kNumPings; ++ping) {
      EXPECT_CALL(handler_, OnWebsocketMessage(_))
          .WillOnce(Invoke([ping](const WebsocketMessage& msg) {
            EXPECT_EQ(msg.data.ToStdString(), "test message");
            // Sized for "PONG" followed by a single digit.
            StackString<6> pong("PONG%d", ping);
            msg.conn->SendWebsocketMessage(pong.c_str(), pong.len());
          }));

      // A frame from a real tcpdump capture:
      // 1... .... = Fin: True
      // .000 .... = Reserved: 0x0
      // .... 0001 = Opcode: Text (1)
      // 1... .... = Mask: True
      // .000 1100 = Payload length: 12
      // Masking-Key: e17e8eb9
      // Masked payload: "test message"
      client.sock.SendStr(
          "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
          "\xdc");
      // Reply: 2-byte header (0x82 = FIN + binary opcode, 0x05 = unmasked
      // 5-byte payload) followed by "PONG<n>".
      EXPECT_EQ(client.Recv(2 + 5), "\x82\x05PONG" + std::to_string(ping));
    }

    client.sock.Shutdown();
    const std::string close_checkpoint = "ws_close_" + std::to_string(conn);
    auto ws_close = task_runner_.CreateCheckpoint(close_checkpoint);
    EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
        .WillOnce(InvokeWithoutArgs(ws_close));
    task_runner_.RunUntilCheckpoint(close_checkpoint);
  }
}
286
TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
  // None of these equals "http://notallowed.com" exactly. The origin must
  // match in full, including scheme, so e.g. "notallowed.com" won't match.
  for (const char* allowed_origin :
       {"http://websocket.com", "http://notallowed.commando",
        "http://iamnotallowed.com", "iamnotallowed.com", "notallowed.com"}) {
    srv_.AddAllowedOrigin(allowed_origin);
  }

  HttpCli client(&task_runner_);
  auto on_close = task_runner_.CreateCheckpoint("close");
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
      .WillOnce(InvokeWithoutArgs(on_close));
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
        EXPECT_TRUE(req.is_websocket_handshake);
        // The upgrade attempt should be rejected (403 expected below).
        req.conn->UpgradeToWebsocket(req);
      }));

  client.SendHttpReq({
      "GET /websocket HTTP/1.1",                      //
      "Origin: http://notallowed.com",                //
      "Connection: upgrade",                          //
      "Upgrade: websocket",                           //
      "Sec-WebSocket-Version: 13",                    //
      "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
  });

  const std::string forbidden_resp =
      "HTTP/1.1 403 Forbidden\r\n"
      "Content-Length: 18\r\n"
      "Connection: close\r\n"
      "\r\n"
      "Origin not allowed";
  EXPECT_EQ(client.Recv(forbidden_resp.size()), forbidden_resp);

  client.sock.Shutdown();
  task_runner_.RunUntilCheckpoint("close");
}
325
326 } // namespace
327 } // namespace base
328 } // namespace perfetto
329