1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/server/http_server.h"
6
7 #include <stdint.h>
8
9 #include <algorithm>
10 #include <memory>
11 #include <string_view>
12 #include <unordered_map>
13 #include <utility>
14 #include <vector>
15
16 #include "base/auto_reset.h"
17 #include "base/check_op.h"
18 #include "base/compiler_specific.h"
19 #include "base/format_macros.h"
20 #include "base/functional/bind.h"
21 #include "base/functional/callback_helpers.h"
22 #include "base/location.h"
23 #include "base/memory/ptr_util.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/weak_ptr.h"
26 #include "base/notreached.h"
27 #include "base/numerics/safe_conversions.h"
28 #include "base/run_loop.h"
29 #include "base/strings/string_split.h"
30 #include "base/strings/string_util.h"
31 #include "base/strings/stringprintf.h"
32 #include "base/task/single_thread_task_runner.h"
33 #include "base/test/test_future.h"
34 #include "base/time/time.h"
35 #include "net/base/address_list.h"
36 #include "net/base/io_buffer.h"
37 #include "net/base/ip_endpoint.h"
38 #include "net/base/net_errors.h"
39 #include "net/base/test_completion_callback.h"
40 #include "net/http/http_response_headers.h"
41 #include "net/http/http_util.h"
42 #include "net/log/net_log_source.h"
43 #include "net/log/net_log_with_source.h"
44 #include "net/server/http_server_request_info.h"
45 #include "net/socket/tcp_client_socket.h"
46 #include "net/socket/tcp_server_socket.h"
47 #include "net/test/gtest_util.h"
48 #include "net/test/test_with_task_environment.h"
49 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
50 #include "net/websockets/websocket_frame.h"
51 #include "testing/gmock/include/gmock/gmock.h"
52 #include "testing/gtest/include/gtest/gtest.h"
53
54 using net::test::IsOk;
55
56 namespace net {
57
58 namespace {
59
// Size, in bytes, of the buffer TestHttpClient allocates for each socket read.
constexpr int kMaxExpectedResponseLength = 2048;
61
62 class TestHttpClient {
63 public:
64 TestHttpClient() = default;
65
ConnectAndWait(const IPEndPoint & address)66 int ConnectAndWait(const IPEndPoint& address) {
67 AddressList addresses(address);
68 NetLogSource source;
69 socket_ = std::make_unique<TCPClientSocket>(addresses, nullptr, nullptr,
70 nullptr, source);
71
72 TestCompletionCallback callback;
73 int rv = socket_->Connect(callback.callback());
74 return callback.GetResult(rv);
75 }
76
Send(const std::string & data)77 void Send(const std::string& data) {
78 write_buffer_ = base::MakeRefCounted<DrainableIOBuffer>(
79 base::MakeRefCounted<StringIOBuffer>(data), data.length());
80 Write();
81 }
82
Read(std::string * message,int expected_bytes)83 bool Read(std::string* message, int expected_bytes) {
84 int total_bytes_received = 0;
85 message->clear();
86 while (total_bytes_received < expected_bytes) {
87 TestCompletionCallback callback;
88 ReadInternal(&callback);
89 int bytes_received = callback.WaitForResult();
90 if (bytes_received <= 0) {
91 return false;
92 }
93
94 total_bytes_received += bytes_received;
95 message->append(read_buffer_->data(), bytes_received);
96 }
97 return true;
98 }
99
ReadResponse(std::string * message)100 bool ReadResponse(std::string* message) {
101 if (!Read(message, 1)) {
102 return false;
103 }
104 while (!IsCompleteResponse(*message)) {
105 std::string chunk;
106 if (!Read(&chunk, 1)) {
107 return false;
108 }
109 message->append(chunk);
110 }
111 return true;
112 }
113
ExpectUsedThenDisconnectedWithNoData()114 void ExpectUsedThenDisconnectedWithNoData() {
115 // Check that the socket was opened...
116 ASSERT_TRUE(socket_->WasEverUsed());
117
118 // ...then closed when the server disconnected. Verify that the socket was
119 // closed by checking that a Read() fails.
120 std::string response;
121 ASSERT_FALSE(Read(&response, 1u));
122 ASSERT_TRUE(response.empty());
123 }
124
socket()125 TCPClientSocket& socket() { return *socket_; }
126
127 private:
Write()128 void Write() {
129 int result = socket_->Write(
130 write_buffer_.get(), write_buffer_->BytesRemaining(),
131 base::BindOnce(&TestHttpClient::OnWrite, base::Unretained(this)),
132 TRAFFIC_ANNOTATION_FOR_TESTS);
133 if (result != ERR_IO_PENDING) {
134 OnWrite(result);
135 }
136 }
137
OnWrite(int result)138 void OnWrite(int result) {
139 ASSERT_GT(result, 0);
140 write_buffer_->DidConsume(result);
141 if (write_buffer_->BytesRemaining()) {
142 Write();
143 }
144 }
145
ReadInternal(TestCompletionCallback * callback)146 void ReadInternal(TestCompletionCallback* callback) {
147 read_buffer_ =
148 base::MakeRefCounted<IOBufferWithSize>(kMaxExpectedResponseLength);
149 int result = socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength,
150 callback->callback());
151 if (result != ERR_IO_PENDING) {
152 callback->callback().Run(result);
153 }
154 }
155
IsCompleteResponse(const std::string & response)156 bool IsCompleteResponse(const std::string& response) {
157 // Check end of headers first.
158 size_t end_of_headers =
159 HttpUtil::LocateEndOfHeaders(response.data(), response.size());
160 if (end_of_headers == std::string::npos) {
161 return false;
162 }
163
164 // Return true if response has data equal to or more than content length.
165 int64_t body_size = static_cast<int64_t>(response.size()) - end_of_headers;
166 DCHECK_LE(0, body_size);
167 auto headers =
168 base::MakeRefCounted<HttpResponseHeaders>(HttpUtil::AssembleRawHeaders(
169 std::string_view(response.data(), end_of_headers)));
170 return body_size >= headers->GetContentLength();
171 }
172
173 scoped_refptr<IOBufferWithSize> read_buffer_;
174 scoped_refptr<DrainableIOBuffer> write_buffer_;
175 std::unique_ptr<TCPClientSocket> socket_;
176 };
177
178 struct ReceivedRequest {
179 HttpServerRequestInfo info;
180 int connection_id;
181 };
182
183 } // namespace
184
185 class HttpServerTest : public TestWithTaskEnvironment,
186 public HttpServer::Delegate {
187 public:
188 HttpServerTest() = default;
189
SetUp()190 void SetUp() override {
191 auto server_socket =
192 std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
193 server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
194 server_ = std::make_unique<HttpServer>(std::move(server_socket), this);
195 ASSERT_THAT(server_->GetLocalAddress(&server_address_), IsOk());
196 }
197
TearDown()198 void TearDown() override {
199 // Run the event loop some to make sure that the memory handed over to
200 // DeleteSoon gets fully freed.
201 base::RunLoop().RunUntilIdle();
202 }
203
OnConnect(int connection_id)204 void OnConnect(int connection_id) override {
205 DCHECK(connection_map_.find(connection_id) == connection_map_.end());
206 connection_map_[connection_id] = true;
207 // This is set in CreateConnection(), which must be invoked once for every
208 // expected connection.
209 quit_on_create_loop_->Quit();
210 }
211
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)212 void OnHttpRequest(int connection_id,
213 const HttpServerRequestInfo& info) override {
214 received_request_.SetValue({.info = info, .connection_id = connection_id});
215 }
216
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)217 void OnWebSocketRequest(int connection_id,
218 const HttpServerRequestInfo& info) override {
219 NOTREACHED();
220 }
221
OnWebSocketMessage(int connection_id,std::string data)222 void OnWebSocketMessage(int connection_id, std::string data) override {
223 NOTREACHED();
224 }
225
OnClose(int connection_id)226 void OnClose(int connection_id) override {
227 DCHECK(connection_map_.find(connection_id) != connection_map_.end());
228 connection_map_[connection_id] = false;
229 if (connection_id == quit_on_close_connection_) {
230 std::move(run_loop_quit_func_).Run();
231 }
232 }
233
WaitForRequest()234 ReceivedRequest WaitForRequest() { return received_request_.Take(); }
235
HasRequest() const236 bool HasRequest() const { return received_request_.IsReady(); }
237
238 // Connections should only be created using this method, which waits until
239 // both the server and the client have received the connected socket.
CreateConnection(TestHttpClient * client)240 void CreateConnection(TestHttpClient* client) {
241 ASSERT_FALSE(quit_on_create_loop_);
242 quit_on_create_loop_ = std::make_unique<base::RunLoop>();
243 EXPECT_THAT(client->ConnectAndWait(server_address_), IsOk());
244 quit_on_create_loop_->Run();
245 quit_on_create_loop_.reset();
246 }
247
RunUntilConnectionIdClosed(int connection_id)248 void RunUntilConnectionIdClosed(int connection_id) {
249 quit_on_close_connection_ = connection_id;
250 auto iter = connection_map_.find(connection_id);
251 if (iter != connection_map_.end() && !iter->second) {
252 // Already disconnected.
253 return;
254 }
255
256 base::RunLoop run_loop;
257 base::AutoReset<base::OnceClosure> run_loop_quit_func(
258 &run_loop_quit_func_, run_loop.QuitClosure());
259 run_loop.Run();
260
261 iter = connection_map_.find(connection_id);
262 ASSERT_TRUE(iter != connection_map_.end());
263 ASSERT_FALSE(iter->second);
264 }
265
HandleAcceptResult(std::unique_ptr<StreamSocket> socket)266 void HandleAcceptResult(std::unique_ptr<StreamSocket> socket) {
267 ASSERT_FALSE(quit_on_create_loop_);
268 quit_on_create_loop_ = std::make_unique<base::RunLoop>();
269 server_->accepted_socket_ = std::move(socket);
270 server_->HandleAcceptResult(OK);
271 quit_on_create_loop_->Run();
272 quit_on_create_loop_.reset();
273 }
274
connection_map()275 std::unordered_map<int, bool>& connection_map() { return connection_map_; }
276
277 protected:
278 std::unique_ptr<HttpServer> server_;
279 IPEndPoint server_address_;
280 base::OnceClosure run_loop_quit_func_;
281 std::unordered_map<int /* connection_id */, bool /* connected */>
282 connection_map_;
283
284 private:
285 base::test::TestFuture<ReceivedRequest> received_request_;
286 std::unique_ptr<base::RunLoop> quit_on_create_loop_;
287 int quit_on_close_connection_ = -1;
288 };
289
290 namespace {
291
292 class WebSocketTest : public HttpServerTest {
OnHttpRequest(int connection_id,const HttpServerRequestInfo & info)293 void OnHttpRequest(int connection_id,
294 const HttpServerRequestInfo& info) override {
295 NOTREACHED();
296 }
297
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)298 void OnWebSocketRequest(int connection_id,
299 const HttpServerRequestInfo& info) override {
300 HttpServerTest::OnHttpRequest(connection_id, info);
301 }
302
OnWebSocketMessage(int connection_id,std::string data)303 void OnWebSocketMessage(int connection_id, std::string data) override {}
304 };
305
306 class WebSocketAcceptingTest : public WebSocketTest {
307 public:
OnWebSocketRequest(int connection_id,const HttpServerRequestInfo & info)308 void OnWebSocketRequest(int connection_id,
309 const HttpServerRequestInfo& info) override {
310 HttpServerTest::OnHttpRequest(connection_id, info);
311 server_->AcceptWebSocket(connection_id, info, TRAFFIC_ANNOTATION_FOR_TESTS);
312 }
313
OnWebSocketMessage(int connection_id,std::string data)314 void OnWebSocketMessage(int connection_id, std::string data) override {
315 last_message_.SetValue(data);
316 }
317
GetMessage()318 std::string GetMessage() { return last_message_.Take(); }
319
320 private:
321 base::test::TestFuture<std::string> last_message_;
322 };
323
EncodeFrame(std::string message,WebSocketFrameHeader::OpCodeEnum op_code,bool mask,bool finish)324 std::string EncodeFrame(std::string message,
325 WebSocketFrameHeader::OpCodeEnum op_code,
326 bool mask,
327 bool finish) {
328 WebSocketFrameHeader header(op_code);
329 header.final = finish;
330 header.masked = mask;
331 header.payload_length = message.size();
332 const size_t header_size = GetWebSocketFrameHeaderSize(header);
333 std::string frame_header;
334 frame_header.resize(header_size);
335 if (mask) {
336 WebSocketMaskingKey masking_key = GenerateWebSocketMaskingKey();
337 WriteWebSocketFrameHeader(header, &masking_key, &frame_header[0],
338 base::checked_cast<int>(header_size));
339 MaskWebSocketFramePayload(masking_key, 0, &message[0], message.size());
340 } else {
341 WriteWebSocketFrameHeader(header, nullptr, &frame_header[0],
342 base::checked_cast<int>(header_size));
343 }
344 return frame_header + message;
345 }
346
// A minimal GET request reaches the delegate with the expected method and
// path, no body, no headers, and a loopback peer address.
TEST_F(HttpServerTest, Request) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");

  const ReceivedRequest received = WaitForRequest();
  const HttpServerRequestInfo& info = received.info;
  ASSERT_EQ("GET", info.method);
  ASSERT_EQ("/test", info.path);
  ASSERT_EQ("", info.data);
  ASSERT_EQ(0u, info.headers.size());
  ASSERT_TRUE(info.peer.ToString().starts_with("127.0.0.1"));
}
358
// A request whose header terminator is malformed must close connection 1
// without ever delivering a request to the delegate.
TEST_F(HttpServerTest, RequestBrokenTermination) {
  static constexpr char kMalformedRequest[] = "GET /test HTTP/1.1\r\n\r)";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(kMalformedRequest);
  RunUntilConnectionIdClosed(1);
  EXPECT_FALSE(HasRequest());
  client.ExpectUsedThenDisconnectedWithNoData();
}
367
// Sends headers with assorted separators and values, then checks that each
// one is stored once under its lower-cased name with the expected value.
TEST_F(HttpServerTest, RequestWithHeaders) {
  TestHttpClient client;
  CreateConnection(&client);
  // Each row is {name, separator, value}.
  const char* const kHeaders[][3] = {
      {"Header", ": ", "1"},
      {"HeaderWithNoWhitespace", ":", "1"},
      {"HeaderWithWhitespace", " : \t ", "1 1 1 \t "},
      {"HeaderWithColon", ": ", "1:1"},
      {"EmptyHeader", ":", ""},
      {"EmptyHeaderWithWhitespace", ": \t ", ""},
      {"HeaderWithNonASCII", ": ", "\xf7"},
  };
  std::string header_block;
  for (const auto& parts : kHeaders) {
    header_block.append(parts[0])
        .append(parts[1])
        .append(parts[2])
        .append("\r\n");
  }

  client.Send("GET /test HTTP/1.1\r\n" + header_block + "\r\n");
  auto request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  for (const auto& parts : kHeaders) {
    const std::string lowered_name = base::ToLowerASCII(std::string(parts[0]));
    const std::string expected_value = parts[2];
    ASSERT_EQ(1u, request.info.headers.count(lowered_name)) << lowered_name;
    ASSERT_EQ(expected_value, request.info.headers[lowered_name]) << parts[0];
  }
}
396
// A header that appears twice is expected to be exposed as one entry whose
// value is the two values joined by a comma ("2,4").
TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
  TestHttpClient client;
  CreateConnection(&client);
  // Each row is {name, separator, value}.
  const char* const kHeaders[][3] = {
      // clang-format off
      {"FirstHeader", ": ", "1"},
      {"DuplicateHeader", ": ", "2"},
      {"MiddleHeader", ": ", "3"},
      {"DuplicateHeader", ": ", "4"},
      {"LastHeader", ": ", "5"},
      // clang-format on
  };
  std::string header_block;
  for (const auto& parts : kHeaders) {
    header_block.append(parts[0])
        .append(parts[1])
        .append(parts[2])
        .append("\r\n");
  }

  client.Send("GET /test HTTP/1.1\r\n" + header_block + "\r\n");
  auto request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  for (const auto& parts : kHeaders) {
    const std::string lowered_name = base::ToLowerASCII(std::string(parts[0]));
    std::string expected_value;
    if (lowered_name == "duplicateheader") {
      expected_value = "2,4";
    } else {
      expected_value = parts[2];
    }
    ASSERT_EQ(1u, request.info.headers.count(lowered_name)) << lowered_name;
    ASSERT_EQ(expected_value, request.info.headers[lowered_name]) << parts[0];
  }
}
425
// Exercises HttpServerRequestInfo::HasHeaderValue() against headers with
// varied whitespace, duplicated names, comma-separated lists and empty
// values.
TEST_F(HttpServerTest, HasHeaderValueTest) {
  TestHttpClient client;
  CreateConnection(&client);
  const char* const kHeaders[] = {
      "Header: Abcd",
      "HeaderWithNoWhitespace:E",
      "HeaderWithWhitespace : \t f \t ",
      "DuplicateHeader: g",
      "HeaderWithComma: h, i ,j",
      "DuplicateHeader: k",
      "EmptyHeader:",
      "EmptyHeaderWithWhitespace: \t ",
      "HeaderWithNonASCII: \xf7",
  };
  std::string header_block;
  for (const char* line : kHeaders) {
    header_block.append(line).append("\r\n");
  }

  client.Send("GET /test HTTP/1.1\r\n" + header_block + "\r\n");
  auto request = WaitForRequest();
  ASSERT_EQ("", request.info.data);

  const HttpServerRequestInfo& info = request.info;
  ASSERT_TRUE(info.HasHeaderValue("header", "abcd"));
  ASSERT_FALSE(info.HasHeaderValue("header", "bc"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithnowhitespace", "e"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithwhitespace", "f"));
  ASSERT_TRUE(info.HasHeaderValue("duplicateheader", "g"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithcomma", "h"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithcomma", "i"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithcomma", "j"));
  ASSERT_TRUE(info.HasHeaderValue("duplicateheader", "k"));
  ASSERT_FALSE(info.HasHeaderValue("emptyheader", "x"));
  ASSERT_FALSE(info.HasHeaderValue("emptyheaderwithwhitespace", "x"));
  ASSERT_TRUE(info.HasHeaderValue("headerwithnonascii", "\xf7"));
}
462
// Sends a request with a Content-Length body of just over 1 KiB and verifies
// that the delegate receives both headers and the body unchanged.
TEST_F(HttpServerTest, RequestWithBody) {
  TestHttpClient client;
  CreateConnection(&client);
  std::string body = "a" + std::string(1 << 10, 'b') + "c";
  client.Send(
      base::StringPrintf("GET /test HTTP/1.1\r\n"
                         "SomeHeader: 1\r\n"
                         "Content-Length: %" PRIuS "\r\n\r\n%s",
                         body.length(), body.c_str()));
  auto request = WaitForRequest();
  ASSERT_EQ(2u, request.info.headers.size());
  ASSERT_EQ(body.length(), request.info.data.length());
  // Inspect the data that was actually received. Checking the local |body|
  // here would compare the test input against itself and could never fail.
  ASSERT_EQ('a', request.info.data.front());
  ASSERT_EQ('c', request.info.data.back());
  EXPECT_EQ(body, request.info.data);
}
478
479 // Tests that |HttpServer::HandleReadResult| will ignore Upgrade header if value
480 // is not WebSocket.
TEST_F(HttpServerTest,UpgradeIgnored)481 TEST_F(HttpServerTest, UpgradeIgnored) {
482 TestHttpClient client;
483 CreateConnection(&client);
484 client.Send(
485 "GET /test HTTP/1.1\r\n"
486 "Upgrade: h2c\r\n"
487 "Connection: SomethingElse, Upgrade\r\n"
488 "\r\n");
489 WaitForRequest();
490 }
491
// A WebSocket upgrade request must be delivered through the WebSocket
// delegate path; this fixture treats a plain HTTP delivery as unreachable.
TEST_F(WebSocketTest, RequestWebSocket) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(kUpgradeRequest);
  WaitForRequest();
}
504
// An upgrade request followed by extra bytes must cause connection 1 to be
// closed, with the client receiving no data.
TEST_F(WebSocketTest, RequestWebSocketTrailingJunk) {
  static constexpr char kUpgradeRequestWithJunk[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n"
      "\r\nHello? Anyone";

  TestHttpClient client;
  CreateConnection(&client);
  client.Send(kUpgradeRequestWithJunk);
  RunUntilConnectionIdClosed(1);
  client.ExpectUsedThenDisconnectedWithNoData();
}
518
// After the handshake, a masked ping with an empty payload must be answered
// with an unmasked pong that also has an empty payload.
TEST_F(WebSocketAcceptingTest, SendPingFrameWithNoMessage) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string payload = "";
  const std::string masked_ping =
      EncodeFrame(payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
                  /*mask=*/true, /*finish=*/true);
  const std::string expected_pong =
      EncodeFrame(payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
                  /*mask=*/false, /*finish=*/true);

  client.Send(masked_ping);
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
542
// After the handshake, a masked ping carrying "hello" must be answered with
// an unmasked pong carrying the same payload.
TEST_F(WebSocketAcceptingTest, SendPingFrameWithMessage) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string payload = "hello";
  const std::string masked_ping =
      EncodeFrame(payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
                  /*mask=*/true, /*finish=*/true);
  const std::string expected_pong =
      EncodeFrame(payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
                  /*mask=*/false, /*finish=*/true);

  client.Send(masked_ping);
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
566
// Sends a pong followed by a ping and expects the only data coming back to be
// the pong that answers the ping.
TEST_F(WebSocketAcceptingTest, SendPongFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string masked_ping = EncodeFrame(
      /*message=*/"", WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
      /*mask=*/true, /*finish=*/true);
  const std::string masked_pong = EncodeFrame(
      /*message=*/"", WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/true, /*finish=*/true);
  const std::string expected_pong = EncodeFrame(
      /*message=*/"", WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/false, /*finish=*/true);

  client.Send(masked_pong);
  client.Send(masked_ping);
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);
}
593
// A 100000-byte text frame followed by a 100000-byte final continuation frame
// must be delivered to the delegate as one concatenated message.
TEST_F(WebSocketAcceptingTest, SendLongTextFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";
  constexpr int kFrameSize = 100000;

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string first_part(kFrameSize, 'a');
  const std::string second_part(kFrameSize, 'b');
  const std::string encoded_first =
      EncodeFrame(first_part, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_second = EncodeFrame(
      second_part, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);

  client.Send(encoded_first);
  client.Send(encoded_second);
  const std::string delivered = GetMessage();
  EXPECT_EQ(delivered.size(), first_part.size() + second_part.size());
  EXPECT_EQ(delivered, first_part + second_part);
}
623
// Two fragmented text messages sent one after another must each be
// reassembled and delivered separately ("foobar", then "FOOBAR").
TEST_F(WebSocketAcceptingTest, SendTwoTextFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  // First message: "foo" + "bar".
  const std::string lower_head = "foo";
  const std::string lower_tail = "bar";
  const std::string encoded_lower_head =
      EncodeFrame(lower_head, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_lower_tail = EncodeFrame(
      lower_tail, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);

  // Second message: "FOO" + "BAR".
  const std::string upper_head = "FOO";
  const std::string upper_tail = "BAR";
  const std::string encoded_upper_head =
      EncodeFrame(upper_head, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_upper_tail = EncodeFrame(
      upper_tail, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);

  // The first message is sent and verified before the second one goes out.
  client.Send(encoded_lower_head);
  client.Send(encoded_lower_tail);
  std::string delivered = GetMessage();
  EXPECT_EQ(delivered, "foobar");

  client.Send(encoded_upper_head);
  client.Send(encoded_upper_tail);
  delivered = GetMessage();
  EXPECT_EQ(delivered, "FOOBAR");
}
668
// Sends ping, pong, ping and expects one pong reply per ping, each echoing
// the payload of the ping it answers.
TEST_F(WebSocketAcceptingTest, SendPingPongFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string empty_payload = "";
  const std::string first_ping = EncodeFrame(
      empty_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
      /*mask=*/true, /*finish=*/true);
  const std::string first_expected_pong = EncodeFrame(
      empty_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/false, /*finish=*/true);
  const std::string unsolicited_pong = EncodeFrame(
      /*message=*/"", WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/true, /*finish=*/true);
  const std::string hello_payload = "hello";
  const std::string second_ping = EncodeFrame(
      hello_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
      /*mask=*/true, /*finish=*/true);
  const std::string second_expected_pong = EncodeFrame(
      hello_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/false, /*finish=*/true);

  // Order on the wire: first_ping -> unsolicited_pong -> second_ping.
  client.Send(first_ping);
  ASSERT_TRUE(client.Read(&reply, first_expected_pong.length()));
  EXPECT_EQ(reply, first_expected_pong);

  client.Send(unsolicited_pong);
  client.Send(second_ping);
  ASSERT_TRUE(client.Read(&reply, second_expected_pong.length()));
  EXPECT_EQ(reply, second_expected_pong);
}
709
// A ping placed between the two fragments of a text message must be answered
// with a pong, and the fragmented message must still arrive as "foobar".
TEST_F(WebSocketAcceptingTest, SendTextAndPingFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string head_text = "foo";
  const std::string tail_text = "bar";
  const std::string encoded_head =
      EncodeFrame(head_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_tail = EncodeFrame(
      tail_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);
  const std::string ping_payload = "ping";
  const std::string masked_ping =
      EncodeFrame(ping_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
                  /*mask=*/true, /*finish=*/true);
  const std::string expected_pong =
      EncodeFrame(ping_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
                  /*mask=*/false, /*finish=*/true);

  // Order on the wire: encoded_head -> masked_ping -> encoded_tail.
  client.Send(encoded_head);
  client.Send(masked_ping);
  client.Send(encoded_tail);
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);

  const std::string delivered = GetMessage();
  EXPECT_EQ(delivered, "foobar");
}
749
// Same scenario as a ping interleaved with a fragmented text message, using
// "hello" as the ping payload that the pong has to echo.
TEST_F(WebSocketAcceptingTest, SendTextAndPingFrameWithMessage) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string head_text = "foo";
  const std::string tail_text = "bar";
  const std::string encoded_head =
      EncodeFrame(head_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_tail = EncodeFrame(
      tail_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);
  const std::string ping_payload = "hello";
  const std::string masked_ping =
      EncodeFrame(ping_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
                  /*mask=*/true, /*finish=*/true);
  const std::string expected_pong =
      EncodeFrame(ping_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
                  /*mask=*/false, /*finish=*/true);

  // Order on the wire: encoded_head -> masked_ping -> encoded_tail.
  client.Send(encoded_head);
  client.Send(masked_ping);
  client.Send(encoded_tail);
  ASSERT_TRUE(client.Read(&reply, expected_pong.length()));
  EXPECT_EQ(reply, expected_pong);

  const std::string delivered = GetMessage();
  EXPECT_EQ(delivered, "foobar");
}
789
// A pong placed between the two fragments of a text message must not disturb
// reassembly: the delegate still receives "foobar".
TEST_F(WebSocketAcceptingTest, SendTextAndPongFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string head_text = "foo";
  const std::string tail_text = "bar";
  const std::string encoded_head =
      EncodeFrame(head_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_tail = EncodeFrame(
      tail_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);
  const std::string pong_payload = "pong";
  const std::string masked_pong =
      EncodeFrame(pong_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
                  /*mask=*/true, /*finish=*/true);

  // Order on the wire: encoded_head -> masked_pong -> encoded_tail.
  client.Send(encoded_head);
  client.Send(masked_pong);
  client.Send(encoded_tail);

  const std::string delivered = GetMessage();
  EXPECT_EQ(delivered, "foobar");
}
824
// Two pings are sent between the fragments of a text message. Each ping must
// get its own pong echoing its payload, and the message must still be
// delivered as "foobar" once the continuation frame arrives.
TEST_F(WebSocketAcceptingTest, SendTextPingPongFrame) {
  static constexpr char kUpgradeRequest[] =
      "GET /test HTTP/1.1\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: SomethingElse, Upgrade\r\n"
      "Sec-WebSocket-Version: 8\r\n"
      "Sec-WebSocket-Key: key\r\n\r\n";

  TestHttpClient client;
  CreateConnection(&client);
  std::string reply;
  client.Send(kUpgradeRequest);
  WaitForRequest();
  ASSERT_TRUE(client.ReadResponse(&reply));

  const std::string head_text = "foo";
  const std::string tail_text = "bar";
  const std::string encoded_head =
      EncodeFrame(head_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
                  /*mask=*/true, /*finish=*/false);
  const std::string encoded_tail = EncodeFrame(
      tail_text, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
      /*mask=*/true, /*finish=*/true);

  const std::string lower_payload = "hello";
  const std::string lower_ping = EncodeFrame(
      lower_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
      /*mask=*/true, /*finish=*/true);
  const std::string lower_expected_pong = EncodeFrame(
      lower_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/false, /*finish=*/true);

  const std::string upper_payload = "HELLO";
  const std::string upper_ping = EncodeFrame(
      upper_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
      /*mask=*/true, /*finish=*/true);
  const std::string upper_expected_pong = EncodeFrame(
      upper_payload, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
      /*mask=*/false, /*finish=*/true);

  // Order on the wire: encoded_head -> lower_ping -> upper_ping ->
  // encoded_tail.
  client.Send(encoded_head);

  client.Send(lower_ping);
  ASSERT_TRUE(client.Read(&reply, lower_expected_pong.length()));
  EXPECT_EQ(reply, lower_expected_pong);

  client.Send(upper_ping);
  ASSERT_TRUE(client.Read(&reply, upper_expected_pong.length()));
  EXPECT_EQ(reply, upper_expected_pong);

  client.Send(encoded_tail);
  const std::string delivered = GetMessage();
  EXPECT_EQ(delivered, "foobar");
}
877
// A request announcing a 1 GiB body must be answered with a 500 response whose
// body explains that the content length is too big or unknown.
TEST_F(HttpServerTest, RequestWithTooLargeBody) {
  static constexpr char kExpectedResponse[] =
      "HTTP/1.1 500 Internal Server Error\r\n"
      "Content-Length:42\r\n"
      "Content-Type:text/html\r\n\r\n"
      "request content-length too big or unknown.";

  TestHttpClient client;
  CreateConnection(&client);

  // Only the headers are sent; the oversized body itself never is.
  client.Send(
      "GET /test HTTP/1.1\r\n"
      "Content-Length: 1073741824\r\n\r\n");

  std::string server_reply;
  ASSERT_TRUE(client.ReadResponse(&server_reply));
  EXPECT_EQ(kExpectedResponse, server_reply);
}
893
// Send200() must produce a response that starts with the 200 status line and
// ends with the supplied body.
TEST_F(HttpServerTest, Send200) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");

  // Reply on the connection the request arrived on.
  const auto incoming_request = WaitForRequest();
  server_->Send200(incoming_request.connection_id, "Response!", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);

  std::string server_reply;
  ASSERT_TRUE(client.ReadResponse(&server_reply));
  ASSERT_TRUE(server_reply.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(server_reply.ends_with("Response!"));
}
907
// Several SendRaw() calls on one connection must reach the client as the plain
// concatenation of their payloads, in call order.
TEST_F(HttpServerTest, SendRaw) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET /test HTTP/1.1\r\n\r\n");
  const auto incoming_request = WaitForRequest();

  // Payloads in the order they are handed to the server.
  const char* const kRawChunks[] = {
      "Raw Data ",
      "More Data",
      "Third Piece of Data",
  };
  for (const char* chunk : kRawChunks) {
    server_->SendRaw(incoming_request.connection_id, chunk,
                     TRAFFIC_ANNOTATION_FOR_TESTS);
  }

  const std::string expected_bytes("Raw Data More DataThird Piece of Data");
  std::string received_bytes;
  ASSERT_TRUE(client.Read(&received_bytes, expected_bytes.length()));
  ASSERT_EQ(expected_bytes, received_bytes);
}
925
// Request lines that do not name HTTP/1.1 must get the connection dropped
// without any data being sent back and without a request being reported.
TEST_F(HttpServerTest, WrongProtocolRequest) {
  const char* const kBadProtocolRequests[] = {
      "GET /test HTTP/1.0\r\n\r\n",
      "GET /test foo\r\n\r\n",
      "GET /test \r\n\r\n",
  };

  // Every malformed request is tried on its own fresh connection.
  for (const char* request_text : kBadProtocolRequests) {
    TestHttpClient client;
    CreateConnection(&client);

    client.Send(request_text);
    client.ExpectUsedThenDisconnectedWithNoData();

    // Exactly one connection must have been recorded, its map entry must now
    // be false, and no request may have been delivered.
    auto& connections = connection_map();
    ASSERT_EQ(1u, connections.size());
    ASSERT_FALSE(connections.begin()->second);
    EXPECT_FALSE(HasRequest());

    // Start the next iteration from an empty connection map.
    connections.clear();
  }
}
949
950 class MockStreamSocket : public StreamSocket {
951 public:
952 MockStreamSocket() = default;
953
954 MockStreamSocket(const MockStreamSocket&) = delete;
955 MockStreamSocket& operator=(const MockStreamSocket&) = delete;
956
957 ~MockStreamSocket() override = default;
958
959 // StreamSocket
Connect(CompletionOnceCallback callback)960 int Connect(CompletionOnceCallback callback) override {
961 return ERR_NOT_IMPLEMENTED;
962 }
Disconnect()963 void Disconnect() override {
964 connected_ = false;
965 if (!read_callback_.is_null()) {
966 read_buf_ = nullptr;
967 read_buf_len_ = 0;
968 std::move(read_callback_).Run(ERR_CONNECTION_CLOSED);
969 }
970 }
IsConnected() const971 bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const972 bool IsConnectedAndIdle() const override { return IsConnected(); }
GetPeerAddress(IPEndPoint * address) const973 int GetPeerAddress(IPEndPoint* address) const override {
974 return ERR_NOT_IMPLEMENTED;
975 }
GetLocalAddress(IPEndPoint * address) const976 int GetLocalAddress(IPEndPoint* address) const override {
977 return ERR_NOT_IMPLEMENTED;
978 }
NetLog() const979 const NetLogWithSource& NetLog() const override { return net_log_; }
WasEverUsed() const980 bool WasEverUsed() const override { return true; }
GetNegotiatedProtocol() const981 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)982 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const983 int64_t GetTotalReceivedBytes() const override {
984 NOTIMPLEMENTED();
985 return 0;
986 }
ApplySocketTag(const SocketTag & tag)987 void ApplySocketTag(const SocketTag& tag) override {}
988
989 // Socket
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)990 int Read(IOBuffer* buf,
991 int buf_len,
992 CompletionOnceCallback callback) override {
993 if (!connected_) {
994 return ERR_SOCKET_NOT_CONNECTED;
995 }
996 if (pending_read_data_.empty()) {
997 read_buf_ = buf;
998 read_buf_len_ = buf_len;
999 read_callback_ = std::move(callback);
1000 return ERR_IO_PENDING;
1001 }
1002 DCHECK_GT(buf_len, 0);
1003 int read_len =
1004 std::min(static_cast<int>(pending_read_data_.size()), buf_len);
1005 memcpy(buf->data(), pending_read_data_.data(), read_len);
1006 pending_read_data_.erase(0, read_len);
1007 return read_len;
1008 }
1009
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1010 int Write(IOBuffer* buf,
1011 int buf_len,
1012 CompletionOnceCallback callback,
1013 const NetworkTrafficAnnotationTag& traffic_annotation) override {
1014 return ERR_NOT_IMPLEMENTED;
1015 }
SetReceiveBufferSize(int32_t size)1016 int SetReceiveBufferSize(int32_t size) override {
1017 return ERR_NOT_IMPLEMENTED;
1018 }
SetSendBufferSize(int32_t size)1019 int SetSendBufferSize(int32_t size) override { return ERR_NOT_IMPLEMENTED; }
1020
DidRead(const char * data,int data_len)1021 void DidRead(const char* data, int data_len) {
1022 if (!read_buf_.get()) {
1023 pending_read_data_.append(data, data_len);
1024 return;
1025 }
1026 int read_len = std::min(data_len, read_buf_len_);
1027 memcpy(read_buf_->data(), data, read_len);
1028 pending_read_data_.assign(data + read_len, data_len - read_len);
1029 read_buf_ = nullptr;
1030 read_buf_len_ = 0;
1031 std::move(read_callback_).Run(read_len);
1032 }
1033
1034 private:
1035 bool connected_ = true;
1036 scoped_refptr<IOBuffer> read_buf_;
1037 int read_buf_len_ = 0;
1038 CompletionOnceCallback read_callback_;
1039 std::string pending_read_data_;
1040 NetLogWithSource net_log_;
1041 };
1042
// Delivers a request whose last two body bytes arrive in a separate read and
// checks that the request is reported only once the body is complete.
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
  auto owned_socket = std::make_unique<MockStreamSocket>();
  // Keep a raw pointer so bytes can still be injected after ownership of the
  // socket has been handed over.
  MockStreamSocket* const mock_socket = owned_socket.get();
  HandleAcceptResult(std::move(owned_socket));

  const std::string body("body");
  const std::string request_text = base::StringPrintf(
      "GET /test HTTP/1.1\r\n"
      "SomeHeader: 1\r\n"
      "Content-Length: %" PRIuS "\r\n\r\n%s",
      body.length(), body.c_str());

  // Everything except the final two bytes goes into the first chunk.
  const size_t first_chunk_length = request_text.length() - 2;
  const char* const request_bytes = request_text.c_str();

  mock_socket->DidRead(request_bytes, first_chunk_length);
  ASSERT_FALSE(HasRequest());

  mock_socket->DidRead(request_bytes + first_chunk_length, 2);
  ASSERT_TRUE(HasRequest());
  ASSERT_EQ(body, WaitForRequest().info.data);
}
1059
// Three requests are sent back to back over one connection, the first with a
// body and the others without. Neither kind may break parsing of the request
// that follows it, and all of them must be reported on the same connection id.
TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
  TestHttpClient client;
  CreateConnection(&client);

  // Request 1: carries a body and is answered with a 200.
  const std::string body = "body";
  client.Send(
      base::StringPrintf("GET /test HTTP/1.1\r\n"
                         "Content-Length: %" PRIuS "\r\n\r\n%s",
                         body.length(), body.c_str()));
  const auto request_one = WaitForRequest();
  ASSERT_EQ(body, request_one.info.data);

  const int shared_connection_id = request_one.connection_id;
  server_->Send200(shared_connection_id, "Content for /test", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string reply_one;
  ASSERT_TRUE(client.ReadResponse(&reply_one));
  ASSERT_TRUE(reply_one.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(reply_one.ends_with("Content for /test"));

  // Request 2: no body, answered with a 404.
  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
  const auto request_two = WaitForRequest();
  ASSERT_EQ("/test2", request_two.info.path);
  ASSERT_EQ(shared_connection_id, request_two.connection_id);

  server_->Send404(shared_connection_id, TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string reply_two;
  ASSERT_TRUE(client.ReadResponse(&reply_two));
  ASSERT_TRUE(reply_two.starts_with("HTTP/1.1 404 Not Found"));

  // Request 3: no body, answered with a 200.
  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
  const auto request_three = WaitForRequest();
  ASSERT_EQ("/test3", request_three.info.path);
  ASSERT_EQ(shared_connection_id, request_three.connection_id);

  server_->Send200(shared_connection_id, "Content for /test3", "text/plain",
                   TRAFFIC_ANNOTATION_FOR_TESTS);
  std::string reply_three;
  ASSERT_TRUE(client.ReadResponse(&reply_three));
  ASSERT_TRUE(reply_three.starts_with("HTTP/1.1 200 OK"));
  ASSERT_TRUE(reply_three.ends_with("Content for /test3"));
}
1103
1104 class CloseOnConnectHttpServerTest : public HttpServerTest {
1105 public:
OnConnect(int connection_id)1106 void OnConnect(int connection_id) override {
1107 HttpServerTest::OnConnect(connection_id);
1108 connection_ids_.push_back(connection_id);
1109 server_->Close(connection_id);
1110 }
1111
1112 protected:
1113 std::vector<int> connection_ids_;
1114 };
1115
// When the connection is closed from within OnConnect(), the client must see
// a disconnect with no data and no request may ever be reported.
TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
  TestHttpClient client;
  CreateConnection(&client);
  client.Send("GET / HTTP/1.1\r\n\r\n");

  // The socket is expected to be closed without a single byte of response.
  client.ExpectUsedThenDisconnectedWithNoData();

  // Let any tasks that were posted in the meantime run to completion.
  base::RunLoop().RunUntilIdle();

  // One connection was seen by the fixture...
  EXPECT_EQ(connection_ids_.size(), 1u);
  // ...but OnHttpRequest() must not have fired, because the connection was
  // closed before anything was read from it.
  EXPECT_FALSE(HasRequest());
}
1132
1133 } // namespace
1134
1135 } // namespace net
1136