xref: /aosp_15_r20/external/perfetto/src/base/http/http_server_unittest.cc (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "perfetto/ext/base/http/http_server.h"
18 
19 #include <initializer_list>
20 #include <string>
21 
22 #include "perfetto/ext/base/string_utils.h"
23 #include "perfetto/ext/base/unix_socket.h"
24 #include "src/base/test/test_task_runner.h"
25 #include "test/gtest_and_gmock.h"
26 
27 namespace perfetto {
28 namespace base {
29 namespace {
30 
31 using testing::_;
32 using testing::Invoke;
33 using testing::InvokeWithoutArgs;
34 using testing::NiceMock;
35 
36 constexpr int kTestPort = 5127;  // Chosen with a fair dice roll.
37 
38 class MockHttpHandler : public HttpRequestHandler {
39  public:
40   MOCK_METHOD(void, OnHttpRequest, (const HttpRequest&), (override));
41   MOCK_METHOD(void,
42               OnHttpConnectionClosed,
43               (HttpServerConnection*),
44               (override));
45   MOCK_METHOD(void, OnWebsocketMessage, (const WebsocketMessage&), (override));
46 };
47 
48 class HttpCli {
49  public:
HttpCli(TestTaskRunner * ttr)50   explicit HttpCli(TestTaskRunner* ttr) : task_runner_(ttr) {
51     sock = UnixSocketRaw::CreateMayFail(SockFamily::kInet, SockType::kStream);
52     sock.SetBlocking(true);
53     sock.Connect("127.0.0.1:" + std::to_string(kTestPort));
54   }
55 
SendHttpReq(std::initializer_list<std::string> headers,const std::string & body="")56   void SendHttpReq(std::initializer_list<std::string> headers,
57                    const std::string& body = "") {
58     for (auto& header : headers)
59       sock.SendStr(header + "\r\n");
60     if (!body.empty())
61       sock.SendStr("Content-Length: " + std::to_string(body.size()) + "\r\n");
62     sock.SendStr("\r\n");
63     sock.SendStr(body);
64   }
65 
Recv(size_t min_bytes)66   std::string Recv(size_t min_bytes) {
67     static int n = 0;
68     auto checkpoint_name = "rx_" + std::to_string(n++);
69     auto checkpoint = task_runner_->CreateCheckpoint(checkpoint_name);
70     std::string rxbuf;
71     sock.SetBlocking(false);
72     task_runner_->AddFileDescriptorWatch(sock.watch_handle(), [&] {
73       char buf[1024]{};
74       auto rsize = PERFETTO_EINTR(sock.Receive(buf, sizeof(buf)));
75       if (rsize < 0)
76         return;
77       rxbuf.append(buf, static_cast<size_t>(rsize));
78       if (rsize == 0 || (min_bytes && rxbuf.length() >= min_bytes))
79         checkpoint();
80     });
81     task_runner_->RunUntilCheckpoint(checkpoint_name);
82     task_runner_->RemoveFileDescriptorWatch(sock.watch_handle());
83     return rxbuf;
84   }
85 
RecvAndWaitConnClose()86   std::string RecvAndWaitConnClose() { return Recv(0); }
87 
88   TestTaskRunner* task_runner_;
89   UnixSocketRaw sock;
90 };
91 
92 class HttpServerTest : public ::testing::Test {
93  public:
HttpServerTest()94   HttpServerTest() : srv_(&task_runner_, &handler_) { srv_.Start(kTestPort); }
95 
96   TestTaskRunner task_runner_;
97   MockHttpHandler handler_;
98   HttpServer srv_;
99 };
100 
// Sends kIterations GET requests, each on a fresh connection. Checks that:
// - the request line and headers reach the handler;
// - the response given to SendResponseAndClose() is what the client reads
//   before the connection is closed.
TEST_F(HttpServerTest, GET) {
  constexpr int kIterations = 3;
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .Times(kIterations)
      .WillRepeatedly(Invoke([](const HttpRequest& req) {
        EXPECT_EQ(req.uri.ToStdString(), "/foo/bar");
        EXPECT_EQ(req.method.ToStdString(), "GET");
        EXPECT_EQ(req.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("42",
                  req.GetHeader("X-header").value_or("N/A").ToStdString());
        EXPECT_EQ("foo",
                  req.GetHeader("X-header2").value_or("N/A").ToStdString());
        EXPECT_FALSE(req.is_websocket_handshake);
        req.conn->SendResponseAndClose("200 OK", {}, "<html>");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(kIterations);

  // Bounded by kIterations rather than a literal so the loop cannot drift from
  // the Times() cardinalities above.
  for (int i = 0; i < kIterations; i++) {
    HttpCli cli(&task_runner_);
    cli.SendHttpReq(
        {
            "GET /foo/bar HTTP/1.1",        //
            "Origin: https://example.com",  //
            "X-header: 42",                 //
            "X-header2: foo",               //
        },
        "");
    EXPECT_EQ(cli.RecvAndWaitConnClose(),
              "HTTP/1.1 200 OK\r\n"
              "Content-Length: 6\r\n"
              "Connection: close\r\n"
              "\r\n<html>");
  }
}
135 
// Replying with only a status (no extra headers, no body) yields an empty
// response body, and the server then closes the connection.
TEST_F(HttpServerTest, GET_404) {
  HttpCli client(&task_runner_);

  // Expectations are set up front; nothing is processed server-side until the
  // task runner is pumped inside RecvAndWaitConnClose().
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([&](const HttpRequest& request) {
        EXPECT_EQ(request.uri.ToStdString(), "/404");
        EXPECT_EQ(request.method.ToStdString(), "GET");
        request.conn->SendResponseAndClose("404 Not Found");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  client.SendHttpReq({"GET /404 HTTP/1.1"}, "");

  const std::string expected =
      "HTTP/1.1 404 Not Found\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(client.RecvAndWaitConnClose(), expected);
}
152 
// A POST body, sent by HttpCli with a Content-Length header, must reach the
// handler byte-for-byte, including embedded CR/LF and a trailing "\r\n\r\n".
TEST_F(HttpServerTest, POST) {
  const std::string post_body = "the\r\npost\nbody\r\n\r\n";
  HttpCli client(&task_runner_);

  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([&](const HttpRequest& request) {
        EXPECT_EQ(request.uri.ToStdString(), "/rpc");
        EXPECT_EQ(request.method.ToStdString(), "POST");
        EXPECT_EQ(request.origin.ToStdString(), "https://example.com");
        EXPECT_EQ("foo",
                  request.GetHeader("X-1").value_or("N/A").ToStdString());
        EXPECT_EQ(request.body.ToStdString(), post_body);
        request.conn->SendResponseAndClose("200 OK");
      }));
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_));

  client.SendHttpReq(
      {"POST /rpc HTTP/1.1", "Origin: https://example.com", "X-1: foo"},
      post_body);

  const std::string expected =
      "HTTP/1.1 200 OK\r\n"
      "Content-Length: 0\r\n"
      "Connection: close\r\n"
      "\r\n";
  EXPECT_EQ(client.RecvAndWaitConnClose(), expected);
}
176 
177 // An unhandled request should cause a HTTP 500.
TEST_F(HttpServerTest,Unhadled_500)178 TEST_F(HttpServerTest, Unhadled_500) {
179   HttpCli cli(&task_runner_);
180   EXPECT_CALL(handler_, OnHttpRequest(_));
181   cli.SendHttpReq({"GET /unhandled HTTP/1.1"});
182   EXPECT_CALL(handler_, OnHttpConnectionClosed(_));
183   EXPECT_EQ(cli.RecvAndWaitConnClose(),
184             "HTTP/1.1 500 Internal Server Error\r\n"
185             "Content-Length: 0\r\n"
186             "Connection: close\r\n"
187             "\r\n");
188 }
189 
190 // Send three requests within the same keepalive connection.
TEST_F(HttpServerTest,POST_Keepalive)191 TEST_F(HttpServerTest, POST_Keepalive) {
192   HttpCli cli(&task_runner_);
193   static const int kNumRequests = 3;
194   int req_num = 0;
195   EXPECT_CALL(handler_, OnHttpConnectionClosed(_)).Times(1);
196   EXPECT_CALL(handler_, OnHttpRequest(_))
197       .Times(3)
198       .WillRepeatedly(Invoke([&](const HttpRequest& req) {
199         EXPECT_EQ(req.uri.ToStdString(), "/" + std::to_string(req_num));
200         EXPECT_EQ(req.method.ToStdString(), "POST");
201         EXPECT_EQ(req.body.ToStdString(), "body" + std::to_string(req_num));
202         req.conn->SendResponseHeaders("200 OK");
203         if (++req_num == kNumRequests)
204           req.conn->Close();
205       }));
206 
207   for (int i = 0; i < kNumRequests; i++) {
208     auto i_str = std::to_string(i);
209     cli.SendHttpReq({"POST /" + i_str + " HTTP/1.1", "Connection: keep-alive"},
210                     "body" + i_str);
211   }
212 
213   std::string expected_response;
214   for (int i = 0; i < kNumRequests; i++) {
215     expected_response +=
216         "HTTP/1.1 200 OK\r\n"
217         "Content-Length: 0\r\n"
218         "Connection: keep-alive\r\n"
219         "\r\n";
220   }
221   EXPECT_EQ(cli.RecvAndWaitConnClose(), expected_response);
222 }
223 
TEST_F(HttpServerTest,Websocket)224 TEST_F(HttpServerTest, Websocket) {
225   srv_.AddAllowedOrigin("http://foo.com");
226   srv_.AddAllowedOrigin("http://websocket.com");
227   for (int rep = 0; rep < 3; rep++) {
228     HttpCli cli(&task_runner_);
229     EXPECT_CALL(handler_, OnHttpRequest(_))
230         .WillOnce(Invoke([&](const HttpRequest& req) {
231           EXPECT_EQ(req.uri.ToStdString(), "/websocket");
232           EXPECT_EQ(req.method.ToStdString(), "GET");
233           EXPECT_EQ(req.origin.ToStdString(), "http://websocket.com");
234           EXPECT_TRUE(req.is_websocket_handshake);
235           req.conn->UpgradeToWebsocket(req);
236         }));
237 
238     cli.SendHttpReq({
239         "GET /websocket HTTP/1.1",                      //
240         "Origin: http://websocket.com",                 //
241         "Connection: upgrade",                          //
242         "Upgrade: websocket",                           //
243         "Sec-WebSocket-Version: 13",                    //
244         "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
245     });
246     std::string expected_resp =
247         "HTTP/1.1 101 Switching Protocols\r\n"
248         "Upgrade: websocket\r\n"
249         "Connection: Upgrade\r\n"
250         "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
251         "Access-Control-Allow-Origin: http://websocket.com\r\n"
252         "Vary: Origin\r\n"
253         "\r\n";
254     EXPECT_EQ(cli.Recv(expected_resp.size()), expected_resp);
255 
256     for (int i = 0; i < 3; i++) {
257       EXPECT_CALL(handler_, OnWebsocketMessage(_))
258           .WillOnce(Invoke([i](const WebsocketMessage& msg) {
259             EXPECT_EQ(msg.data.ToStdString(), "test message");
260             StackString<6> resp("PONG%d", i);
261             msg.conn->SendWebsocketMessage(resp.c_str(), resp.len());
262           }));
263 
264       // A frame from a real tcpdump capture:
265       //   1... .... = Fin: True
266       //   .000 .... = Reserved: 0x0
267       //   .... 0001 = Opcode: Text (1)
268       //   1... .... = Mask: True
269       //   .000 1100 = Payload length: 12
270       //   Masking-Key: e17e8eb9
271       //   Masked payload: "test message"
272       cli.sock.SendStr(
273           "\x81\x8c\xe1\x7e\x8e\xb9\x95\x1b\xfd\xcd\xc1\x13\xeb\xca\x92\x1f\xe9"
274           "\xdc");
275       EXPECT_EQ(cli.Recv(2 + 5), "\x82\x05PONG" + std::to_string(i));
276     }
277 
278     cli.sock.Shutdown();
279     auto checkpoint_name = "ws_close_" + std::to_string(rep);
280     auto ws_close = task_runner_.CreateCheckpoint(checkpoint_name);
281     EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
282         .WillOnce(InvokeWithoutArgs(ws_close));
283     task_runner_.RunUntilCheckpoint(checkpoint_name);
284   }
285 }
286 
// The upgrade request comes from an origin that is not on the allow-list.
// Even though the handler calls UpgradeToWebsocket(), the server rejects it
// with a 403, and the connection is then closed.
TEST_F(HttpServerTest, Websocket_OriginNotAllowed) {
  // Near-misses of "http://notallowed.com"; none of them matches exactly.
  srv_.AddAllowedOrigin("http://websocket.com");
  srv_.AddAllowedOrigin("http://notallowed.commando");
  srv_.AddAllowedOrigin("http://iamnotallowed.com");
  srv_.AddAllowedOrigin("iamnotallowed.com");
  // The origin must match in full, including the scheme, so this entry does
  // not allow "http://notallowed.com" either.
  srv_.AddAllowedOrigin("notallowed.com");

  HttpCli cli(&task_runner_);
  EXPECT_CALL(handler_, OnHttpRequest(_))
      .WillOnce(Invoke([&](const HttpRequest& req) {
        EXPECT_EQ(req.origin.ToStdString(), "http://notallowed.com");
        EXPECT_TRUE(req.is_websocket_handshake);
        req.conn->UpgradeToWebsocket(req);
      }));
  auto on_close = task_runner_.CreateCheckpoint("close");
  EXPECT_CALL(handler_, OnHttpConnectionClosed(_))
      .WillOnce(InvokeWithoutArgs(on_close));

  cli.SendHttpReq({
      "GET /websocket HTTP/1.1",                      //
      "Origin: http://notallowed.com",                //
      "Connection: upgrade",                          //
      "Upgrade: websocket",                           //
      "Sec-WebSocket-Version: 13",                    //
      "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",  //
  });

  const std::string forbidden_resp =
      "HTTP/1.1 403 Forbidden\r\n"
      "Content-Length: 18\r\n"
      "Connection: close\r\n"
      "\r\n"
      "Origin not allowed";
  EXPECT_EQ(cli.Recv(forbidden_resp.size()), forbidden_resp);

  cli.sock.Shutdown();
  task_runner_.RunUntilCheckpoint("close");
}
325 
326 }  // namespace
327 }  // namespace base
328 }  // namespace perfetto
329