xref: /aosp_15_r20/external/cronet/net/test/embedded_test_server/embedded_test_server.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/embedded_test_server/embedded_test_server.h"
6 
7 #include <stdint.h>
8 
9 #include <memory>
10 #include <optional>
11 #include <string_view>
12 #include <utility>
13 
14 #include "base/files/file_path.h"
15 #include "base/functional/bind.h"
16 #include "base/functional/callback_forward.h"
17 #include "base/functional/callback_helpers.h"
18 #include "base/location.h"
19 #include "base/logging.h"
20 #include "base/message_loop/message_pump_type.h"
21 #include "base/path_service.h"
22 #include "base/process/process_metrics.h"
23 #include "base/run_loop.h"
24 #include "base/strings/string_number_conversions.h"
25 #include "base/strings/string_util.h"
26 #include "base/strings/stringprintf.h"
27 #include "base/task/current_thread.h"
28 #include "base/task/single_thread_task_executor.h"
29 #include "base/task/single_thread_task_runner.h"
30 #include "base/test/bind.h"
31 #include "base/threading/thread_restrictions.h"
32 #include "crypto/rsa_private_key.h"
33 #include "net/base/hex_utils.h"
34 #include "net/base/ip_endpoint.h"
35 #include "net/base/net_errors.h"
36 #include "net/base/port_util.h"
37 #include "net/log/net_log_source.h"
38 #include "net/socket/next_proto.h"
39 #include "net/socket/ssl_server_socket.h"
40 #include "net/socket/stream_socket.h"
41 #include "net/socket/tcp_server_socket.h"
42 #include "net/spdy/spdy_test_util_common.h"
43 #include "net/ssl/ssl_info.h"
44 #include "net/ssl/ssl_server_config.h"
45 #include "net/test/cert_builder.h"
46 #include "net/test/cert_test_util.h"
47 #include "net/test/embedded_test_server/default_handlers.h"
48 #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
49 #include "net/test/embedded_test_server/http_request.h"
50 #include "net/test/embedded_test_server/http_response.h"
51 #include "net/test/embedded_test_server/request_handler_util.h"
52 #include "net/test/key_util.h"
53 #include "net/test/revocation_builder.h"
54 #include "net/test/test_data_directory.h"
55 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_frame_builder.h"
56 #include "third_party/boringssl/src/pki/extended_key_usage.h"
57 #include "url/origin.h"
58 
59 namespace net::test_server {
60 
61 namespace {
62 
ServeResponseForPath(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)63 std::unique_ptr<HttpResponse> ServeResponseForPath(
64     const std::string& expected_path,
65     HttpStatusCode status_code,
66     const std::string& content_type,
67     const std::string& content,
68     const HttpRequest& request) {
69   if (request.GetURL().path() != expected_path)
70     return nullptr;
71 
72   auto http_response = std::make_unique<BasicHttpResponse>();
73   http_response->set_code(status_code);
74   http_response->set_content_type(content_type);
75   http_response->set_content(content);
76   return http_response;
77 }
78 
79 // Serves response for |expected_path| or any subpath of it.
80 // |expected_path| should not include a trailing "/".
ServeResponseForSubPaths(const std::string & expected_path,HttpStatusCode status_code,const std::string & content_type,const std::string & content,const HttpRequest & request)81 std::unique_ptr<HttpResponse> ServeResponseForSubPaths(
82     const std::string& expected_path,
83     HttpStatusCode status_code,
84     const std::string& content_type,
85     const std::string& content,
86     const HttpRequest& request) {
87   if (request.GetURL().path() != expected_path &&
88       !request.GetURL().path().starts_with(expected_path + "/")) {
89     return nullptr;
90   }
91 
92   auto http_response = std::make_unique<BasicHttpResponse>();
93   http_response->set_code(status_code);
94   http_response->set_content_type(content_type);
95   http_response->set_content(content);
96   return http_response;
97 }
98 
MaybeCreateOCSPResponse(CertBuilder * target,const EmbeddedTestServer::OCSPConfig & config,std::string * out_response)99 bool MaybeCreateOCSPResponse(CertBuilder* target,
100                              const EmbeddedTestServer::OCSPConfig& config,
101                              std::string* out_response) {
102   using OCSPResponseType = EmbeddedTestServer::OCSPConfig::ResponseType;
103 
104   if (!config.single_responses.empty() &&
105       config.response_type != OCSPResponseType::kSuccessful) {
106     // OCSPConfig contained single_responses for a non-successful response.
107     return false;
108   }
109 
110   if (config.response_type == OCSPResponseType::kOff) {
111     *out_response = std::string();
112     return true;
113   }
114 
115   if (!target) {
116     // OCSPConfig enabled but corresponding certificate is null.
117     return false;
118   }
119 
120   switch (config.response_type) {
121     case OCSPResponseType::kOff:
122       return false;
123     case OCSPResponseType::kMalformedRequest:
124       *out_response = BuildOCSPResponseError(
125           bssl::OCSPResponse::ResponseStatus::MALFORMED_REQUEST);
126       return true;
127     case OCSPResponseType::kInternalError:
128       *out_response = BuildOCSPResponseError(
129           bssl::OCSPResponse::ResponseStatus::INTERNAL_ERROR);
130       return true;
131     case OCSPResponseType::kTryLater:
132       *out_response =
133           BuildOCSPResponseError(bssl::OCSPResponse::ResponseStatus::TRY_LATER);
134       return true;
135     case OCSPResponseType::kSigRequired:
136       *out_response = BuildOCSPResponseError(
137           bssl::OCSPResponse::ResponseStatus::SIG_REQUIRED);
138       return true;
139     case OCSPResponseType::kUnauthorized:
140       *out_response = BuildOCSPResponseError(
141           bssl::OCSPResponse::ResponseStatus::UNAUTHORIZED);
142       return true;
143     case OCSPResponseType::kInvalidResponse:
144       *out_response = "3";
145       return true;
146     case OCSPResponseType::kInvalidResponseData:
147       *out_response =
148           BuildOCSPResponseWithResponseData(target->issuer()->GetKey(),
149                                             // OCTET_STRING { "not ocsp data" }
150                                             "\x04\x0dnot ocsp data");
151       return true;
152     case OCSPResponseType::kSuccessful:
153       break;
154   }
155 
156   base::Time now = base::Time::Now();
157   base::Time target_not_before, target_not_after;
158   if (!target->GetValidity(&target_not_before, &target_not_after))
159     return false;
160   base::Time produced_at;
161   using OCSPProduced = EmbeddedTestServer::OCSPConfig::Produced;
162   switch (config.produced) {
163     case OCSPProduced::kValid:
164       produced_at = std::max(now - base::Days(1), target_not_before);
165       break;
166     case OCSPProduced::kBeforeCert:
167       produced_at = target_not_before - base::Days(1);
168       break;
169     case OCSPProduced::kAfterCert:
170       produced_at = target_not_after + base::Days(1);
171       break;
172   }
173 
174   std::vector<OCSPBuilderSingleResponse> responses;
175   for (const auto& config_response : config.single_responses) {
176     OCSPBuilderSingleResponse response;
177     response.serial = target->GetSerialNumber();
178     if (config_response.serial ==
179         EmbeddedTestServer::OCSPConfig::SingleResponse::Serial::kMismatch) {
180       response.serial ^= 1;
181     }
182     response.cert_status = config_response.cert_status;
183     // |revocation_time| is ignored if |cert_status| is not REVOKED.
184     response.revocation_time = now - base::Days(1000);
185 
186     using OCSPDate = EmbeddedTestServer::OCSPConfig::SingleResponse::Date;
187     switch (config_response.ocsp_date) {
188       case OCSPDate::kValid:
189         response.this_update = now - base::Days(1);
190         response.next_update = response.this_update + base::Days(7);
191         break;
192       case OCSPDate::kOld:
193         response.this_update = now - base::Days(8);
194         response.next_update = response.this_update + base::Days(7);
195         break;
196       case OCSPDate::kEarly:
197         response.this_update = now + base::Days(1);
198         response.next_update = response.this_update + base::Days(7);
199         break;
200       case OCSPDate::kLong:
201         response.this_update = now - base::Days(365);
202         response.next_update = response.this_update + base::Days(366);
203         break;
204       case OCSPDate::kLonger:
205         response.this_update = now - base::Days(367);
206         response.next_update = response.this_update + base::Days(368);
207         break;
208     }
209 
210     responses.push_back(response);
211   }
212   *out_response =
213       BuildOCSPResponse(target->issuer()->GetSubject(),
214                         target->issuer()->GetKey(), produced_at, responses);
215   return true;
216 }
217 
218 }  // namespace
219 
EmbeddedTestServerHandle(EmbeddedTestServerHandle && other)220 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
221     EmbeddedTestServerHandle&& other) {
222   operator=(std::move(other));
223 }
224 
operator =(EmbeddedTestServerHandle && other)225 EmbeddedTestServerHandle& EmbeddedTestServerHandle::operator=(
226     EmbeddedTestServerHandle&& other) {
227   EmbeddedTestServerHandle temporary;
228   std::swap(other.test_server_, temporary.test_server_);
229   std::swap(temporary.test_server_, test_server_);
230   return *this;
231 }
232 
EmbeddedTestServerHandle(EmbeddedTestServer * test_server)233 EmbeddedTestServerHandle::EmbeddedTestServerHandle(
234     EmbeddedTestServer* test_server)
235     : test_server_(test_server) {}
236 
~EmbeddedTestServerHandle()237 EmbeddedTestServerHandle::~EmbeddedTestServerHandle() {
238   if (test_server_)
239     CHECK(test_server_->ShutdownAndWaitUntilComplete());
240 }
241 
242 EmbeddedTestServer::OCSPConfig::OCSPConfig() = default;
OCSPConfig(ResponseType response_type)243 EmbeddedTestServer::OCSPConfig::OCSPConfig(ResponseType response_type)
244     : response_type(response_type) {}
OCSPConfig(std::vector<SingleResponse> single_responses,Produced produced)245 EmbeddedTestServer::OCSPConfig::OCSPConfig(
246     std::vector<SingleResponse> single_responses,
247     Produced produced)
248     : response_type(ResponseType::kSuccessful),
249       produced(produced),
250       single_responses(std::move(single_responses)) {}
251 EmbeddedTestServer::OCSPConfig::OCSPConfig(const OCSPConfig&) = default;
252 EmbeddedTestServer::OCSPConfig::OCSPConfig(OCSPConfig&&) = default;
253 EmbeddedTestServer::OCSPConfig::~OCSPConfig() = default;
254 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
255     const OCSPConfig&) = default;
256 EmbeddedTestServer::OCSPConfig& EmbeddedTestServer::OCSPConfig::operator=(
257     OCSPConfig&&) = default;
258 
259 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig() =
260     default;
261 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
262     const ServerCertificateConfig&) = default;
263 EmbeddedTestServer::ServerCertificateConfig::ServerCertificateConfig(
264     ServerCertificateConfig&&) = default;
265 EmbeddedTestServer::ServerCertificateConfig::~ServerCertificateConfig() =
266     default;
267 EmbeddedTestServer::ServerCertificateConfig&
268 EmbeddedTestServer::ServerCertificateConfig::operator=(
269     const ServerCertificateConfig&) = default;
270 EmbeddedTestServer::ServerCertificateConfig&
271 EmbeddedTestServer::ServerCertificateConfig::operator=(
272     ServerCertificateConfig&&) = default;
273 
EmbeddedTestServer()274 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
275 
EmbeddedTestServer(Type type,HttpConnection::Protocol protocol)276 EmbeddedTestServer::EmbeddedTestServer(Type type,
277                                        HttpConnection::Protocol protocol)
278     : is_using_ssl_(type == TYPE_HTTPS), protocol_(protocol) {
279   DCHECK(thread_checker_.CalledOnValidThread());
280   // HTTP/2 is only valid by negotiation via TLS ALPN
281   DCHECK(protocol_ != HttpConnection::Protocol::kHttp2 || type == TYPE_HTTPS);
282 
283   if (!is_using_ssl_)
284     return;
285   scoped_test_root_ = RegisterTestCerts();
286 }
287 
~EmbeddedTestServer()288 EmbeddedTestServer::~EmbeddedTestServer() {
289   DCHECK(thread_checker_.CalledOnValidThread());
290 
291   if (Started())
292     CHECK(ShutdownAndWaitUntilComplete());
293 
294   {
295     base::ScopedAllowBaseSyncPrimitivesForTesting allow_wait_for_thread_join;
296     io_thread_.reset();
297   }
298 }
299 
RegisterTestCerts()300 ScopedTestRoot EmbeddedTestServer::RegisterTestCerts() {
301   base::ScopedAllowBlockingForTesting allow_blocking;
302   auto root = ImportCertFromFile(GetRootCertPemPath());
303   if (!root)
304     return ScopedTestRoot();
305   return ScopedTestRoot(root);
306 }
307 
SetConnectionListener(EmbeddedTestServerConnectionListener * listener)308 void EmbeddedTestServer::SetConnectionListener(
309     EmbeddedTestServerConnectionListener* listener) {
310   DCHECK(!io_thread_)
311       << "ConnectionListener must be set before starting the server.";
312   connection_listener_ = listener;
313 }
314 
StartAndReturnHandle(int port)315 EmbeddedTestServerHandle EmbeddedTestServer::StartAndReturnHandle(int port) {
316   bool result = Start(port);
317   return result ? EmbeddedTestServerHandle(this) : EmbeddedTestServerHandle();
318 }
319 
Start(int port,std::string_view address)320 bool EmbeddedTestServer::Start(int port, std::string_view address) {
321   bool success = InitializeAndListen(port, address);
322   if (success)
323     StartAcceptingConnections();
324   return success;
325 }
326 
InitializeAndListen(int port,std::string_view address)327 bool EmbeddedTestServer::InitializeAndListen(int port,
328                                              std::string_view address) {
329   DCHECK(!Started());
330 
331   const int max_tries = 5;
332   int num_tries = 0;
333   bool is_valid_port = false;
334 
335   do {
336     if (++num_tries > max_tries) {
337       LOG(ERROR) << "Failed to listen on a valid port after " << max_tries
338                  << " attempts.";
339       listen_socket_.reset();
340       return false;
341     }
342 
343     listen_socket_ = std::make_unique<TCPServerSocket>(nullptr, NetLogSource());
344 
345     int result =
346         listen_socket_->ListenWithAddressAndPort(address.data(), port, 10);
347     if (result) {
348       LOG(ERROR) << "Listen failed: " << ErrorToString(result);
349       listen_socket_.reset();
350       return false;
351     }
352 
353     result = listen_socket_->GetLocalAddress(&local_endpoint_);
354     if (result != OK) {
355       LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
356       listen_socket_.reset();
357       return false;
358     }
359 
360     port_ = local_endpoint_.port();
361     is_valid_port |= net::IsPortAllowedForScheme(
362         port_, is_using_ssl_ ? url::kHttpsScheme : url::kHttpScheme);
363   } while (!is_valid_port);
364 
365   if (is_using_ssl_) {
366     base_url_ = GURL("https://" + local_endpoint_.ToString());
367     if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
368       base_url_ = GURL(
369           base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
370     }
371   } else {
372     base_url_ = GURL("http://" + local_endpoint_.ToString());
373   }
374 
375   listen_socket_->DetachFromThread();
376 
377   if (is_using_ssl_ && !InitializeSSLServerContext())
378     return false;
379 
380   return true;
381 }
382 
UsingStaticCert() const383 bool EmbeddedTestServer::UsingStaticCert() const {
384   return !GetCertificateName().empty();
385 }
386 
InitializeCertAndKeyFromFile()387 bool EmbeddedTestServer::InitializeCertAndKeyFromFile() {
388   base::ScopedAllowBlockingForTesting allow_blocking;
389   base::FilePath certs_dir(GetTestCertsDirectory());
390   std::string cert_name = GetCertificateName();
391   if (cert_name.empty())
392     return false;
393 
394   x509_cert_ = CreateCertificateChainFromFile(certs_dir, cert_name,
395                                               X509Certificate::FORMAT_AUTO);
396   if (!x509_cert_)
397     return false;
398 
399   private_key_ =
400       key_util::LoadEVP_PKEYFromPEM(certs_dir.AppendASCII(cert_name));
401   return !!private_key_;
402 }
403 
GenerateCertAndKey()404 bool EmbeddedTestServer::GenerateCertAndKey() {
405   // Create AIA server and start listening. Need to have the socket initialized
406   // so the URL can be put in the AIA records of the generated certs.
407   aia_http_server_ = std::make_unique<EmbeddedTestServer>(TYPE_HTTP);
408   if (!aia_http_server_->InitializeAndListen())
409     return false;
410 
411   base::ScopedAllowBlockingForTesting allow_blocking;
412   base::FilePath certs_dir(GetTestCertsDirectory());
413 
414   std::unique_ptr<CertBuilder> static_root = CertBuilder::FromStaticCertFile(
415       certs_dir.AppendASCII("root_ca_cert.pem"));
416 
417   auto now = base::Time::Now();
418   // Will be nullptr if cert_config_.intermediate == kNone.
419   std::unique_ptr<CertBuilder> intermediate;
420   std::unique_ptr<CertBuilder> leaf;
421 
422   if (cert_config_.intermediate != IntermediateType::kNone) {
423     intermediate = CertBuilder::FromFile(
424         certs_dir.AppendASCII("intermediate_ca_cert.pem"), static_root.get());
425     if (!intermediate)
426       return false;
427     intermediate->SetValidity(now - base::Days(100), now + base::Days(1000));
428 
429     leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
430                                  intermediate.get());
431   } else {
432     leaf = CertBuilder::FromFile(certs_dir.AppendASCII("ok_cert.pem"),
433                                  static_root.get());
434   }
435   if (!leaf)
436     return false;
437 
438   std::vector<GURL> leaf_ca_issuers_urls;
439   std::vector<GURL> leaf_ocsp_urls;
440 
441   leaf->SetValidity(now - base::Days(1), now + base::Days(20));
442 
443   if (!cert_config_.policy_oids.empty()) {
444     leaf->SetCertificatePolicies(cert_config_.policy_oids);
445     if (intermediate)
446       intermediate->SetCertificatePolicies(cert_config_.policy_oids);
447   }
448 
449   if (!cert_config_.dns_names.empty() || !cert_config_.ip_addresses.empty()) {
450     leaf->SetSubjectAltNames(cert_config_.dns_names, cert_config_.ip_addresses);
451   }
452 
453   if (!cert_config_.key_usages.empty()) {
454     leaf->SetKeyUsages(cert_config_.key_usages);
455   }
456 
457   if (!cert_config_.embedded_scts.empty()) {
458     leaf->SetSctConfig(cert_config_.embedded_scts);
459   }
460 
461   const std::string leaf_serial_text =
462       base::NumberToString(leaf->GetSerialNumber());
463   const std::string intermediate_serial_text =
464       intermediate ? base::NumberToString(intermediate->GetSerialNumber()) : "";
465 
466   std::string ocsp_response;
467   if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.ocsp_config,
468                                &ocsp_response)) {
469     return false;
470   }
471   if (!ocsp_response.empty()) {
472     std::string ocsp_path = "/ocsp/" + leaf_serial_text;
473     leaf_ocsp_urls.push_back(aia_http_server_->GetURL(ocsp_path));
474     aia_http_server_->RegisterRequestHandler(
475         base::BindRepeating(ServeResponseForSubPaths, ocsp_path, HTTP_OK,
476                             "application/ocsp-response", ocsp_response));
477   }
478 
479   std::string stapled_ocsp_response;
480   if (!MaybeCreateOCSPResponse(leaf.get(), cert_config_.stapled_ocsp_config,
481                                &stapled_ocsp_response)) {
482     return false;
483   }
484   if (!stapled_ocsp_response.empty()) {
485     ssl_config_.ocsp_response = std::vector<uint8_t>(
486         stapled_ocsp_response.begin(), stapled_ocsp_response.end());
487   }
488 
489   std::string intermediate_ocsp_response;
490   if (!MaybeCreateOCSPResponse(intermediate.get(),
491                                cert_config_.intermediate_ocsp_config,
492                                &intermediate_ocsp_response)) {
493     return false;
494   }
495   if (!intermediate_ocsp_response.empty()) {
496     std::string intermediate_ocsp_path = "/ocsp/" + intermediate_serial_text;
497     intermediate->SetCaIssuersAndOCSPUrls(
498         {}, {aia_http_server_->GetURL(intermediate_ocsp_path)});
499     aia_http_server_->RegisterRequestHandler(base::BindRepeating(
500         ServeResponseForSubPaths, intermediate_ocsp_path, HTTP_OK,
501         "application/ocsp-response", intermediate_ocsp_response));
502   }
503 
504   if (cert_config_.intermediate == IntermediateType::kByAIA) {
505     std::string ca_issuers_path = "/ca_issuers/" + intermediate_serial_text;
506     leaf_ca_issuers_urls.push_back(aia_http_server_->GetURL(ca_issuers_path));
507 
508     // Setup AIA server to serve the intermediate referred to by the leaf.
509     aia_http_server_->RegisterRequestHandler(
510         base::BindRepeating(ServeResponseForPath, ca_issuers_path, HTTP_OK,
511                             "application/pkix-cert", intermediate->GetDER()));
512   }
513 
514   if (!leaf_ca_issuers_urls.empty() || !leaf_ocsp_urls.empty()) {
515     leaf->SetCaIssuersAndOCSPUrls(leaf_ca_issuers_urls, leaf_ocsp_urls);
516   }
517 
518   if (cert_config_.intermediate == IntermediateType::kByAIA ||
519       cert_config_.intermediate == IntermediateType::kMissing) {
520     // Server certificate chain does not include the intermediate.
521     x509_cert_ = leaf->GetX509Certificate();
522   } else {
523     // Server certificate chain will include the intermediate, if there is one.
524     x509_cert_ = leaf->GetX509CertificateChain();
525   }
526 
527   if (intermediate) {
528     intermediate_ = intermediate->GetX509Certificate();
529   }
530 
531   private_key_ = bssl::UpRef(leaf->GetKey());
532 
533   // If this server is already accepting connections but is being reconfigured,
534   // start the new AIA server now. Otherwise, wait until
535   // StartAcceptingConnections so that this server and the AIA server start at
536   // the same time. (If the test only called InitializeAndListen they expect no
537   // threads to be created yet.)
538   if (io_thread_)
539     aia_http_server_->StartAcceptingConnections();
540 
541   return true;
542 }
543 
InitializeSSLServerContext()544 bool EmbeddedTestServer::InitializeSSLServerContext() {
545   if (UsingStaticCert()) {
546     if (!InitializeCertAndKeyFromFile())
547       return false;
548   } else {
549     if (!GenerateCertAndKey())
550       return false;
551   }
552 
553   if (protocol_ == HttpConnection::Protocol::kHttp2) {
554     ssl_config_.alpn_protos = {NextProto::kProtoHTTP2};
555     if (!alps_accept_ch_.empty()) {
556       base::StringPairs origin_accept_ch;
557       size_t frame_size = spdy::kFrameHeaderSize;
558       // Figure out size and generate origins
559       for (const auto& pair : alps_accept_ch_) {
560         std::string_view hostname = pair.first;
561         std::string accept_ch = pair.second;
562 
563         GURL url = hostname.empty() ? GetURL("/") : GetURL(hostname, "/");
564         std::string origin = url::Origin::Create(url).Serialize();
565 
566         frame_size += accept_ch.size() + origin.size() +
567                       (sizeof(uint16_t) * 2);  // = Origin-Len + Value-Len
568 
569         origin_accept_ch.push_back({std::move(origin), std::move(accept_ch)});
570       }
571 
572       spdy::SpdyFrameBuilder builder(frame_size);
573       builder.BeginNewFrame(spdy::SpdyFrameType::ACCEPT_CH, 0, 0);
574       for (const auto& pair : origin_accept_ch) {
575         std::string_view origin = pair.first;
576         std::string_view accept_ch = pair.second;
577 
578         builder.WriteUInt16(origin.size());
579         builder.WriteBytes(origin.data(), origin.size());
580 
581         builder.WriteUInt16(accept_ch.size());
582         builder.WriteBytes(accept_ch.data(), accept_ch.size());
583       }
584 
585       spdy::SpdySerializedFrame serialized_frame = builder.take();
586       DCHECK_EQ(frame_size, serialized_frame.size());
587 
588       ssl_config_.application_settings[NextProto::kProtoHTTP2] =
589           std::vector<uint8_t>(
590               serialized_frame.data(),
591               serialized_frame.data() + serialized_frame.size());
592 
593       ssl_config_.client_hello_callback_for_testing =
594           base::BindRepeating([](const SSL_CLIENT_HELLO* client_hello) {
595             // Configure the server to use the ALPS codepoint that the client
596             // offered.
597             const uint8_t* unused_extension_bytes;
598             size_t unused_extension_len;
599             int use_alps_new_codepoint = SSL_early_callback_ctx_extension_get(
600                 client_hello, TLSEXT_TYPE_application_settings,
601                 &unused_extension_bytes, &unused_extension_len);
602             // Make sure we use the right ALPS codepoint.
603             SSL_set_alps_use_new_codepoint(client_hello->ssl,
604                                            use_alps_new_codepoint);
605             return true;
606           });
607     }
608   }
609 
610   context_ =
611       CreateSSLServerContext(x509_cert_.get(), private_key_.get(), ssl_config_);
612   return true;
613 }
614 
615 EmbeddedTestServerHandle
StartAcceptingConnectionsAndReturnHandle()616 EmbeddedTestServer::StartAcceptingConnectionsAndReturnHandle() {
617   StartAcceptingConnections();
618   return EmbeddedTestServerHandle(this);
619 }
620 
StartAcceptingConnections()621 void EmbeddedTestServer::StartAcceptingConnections() {
622   DCHECK(Started());
623   DCHECK(!io_thread_) << "Server must not be started while server is running";
624 
625   if (aia_http_server_)
626     aia_http_server_->StartAcceptingConnections();
627 
628   base::Thread::Options thread_options;
629   thread_options.message_pump_type = base::MessagePumpType::IO;
630   io_thread_ = std::make_unique<base::Thread>("EmbeddedTestServer IO Thread");
631   CHECK(io_thread_->StartWithOptions(std::move(thread_options)));
632   CHECK(io_thread_->WaitUntilThreadStarted());
633 
634   io_thread_->task_runner()->PostTask(
635       FROM_HERE, base::BindOnce(&EmbeddedTestServer::DoAcceptLoop,
636                                 base::Unretained(this)));
637 }
638 
ShutdownAndWaitUntilComplete()639 bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
640   DCHECK(thread_checker_.CalledOnValidThread());
641 
642   if (!io_thread_) {
643     // Can't stop a server that never started.
644     return true;
645   }
646 
647   // Ensure that the AIA HTTP server is no longer Started().
648   bool aia_http_server_not_started = true;
649   if (aia_http_server_ && aia_http_server_->Started()) {
650     aia_http_server_not_started =
651         aia_http_server_->ShutdownAndWaitUntilComplete();
652   }
653 
654   // Return false if either this or the AIA HTTP server are still Started().
655   return PostTaskToIOThreadAndWait(
656              base::BindOnce(&EmbeddedTestServer::ShutdownOnIOThread,
657                             base::Unretained(this))) &&
658          aia_http_server_not_started;
659 }
660 
661 // static
GetRootCertPemPath()662 base::FilePath EmbeddedTestServer::GetRootCertPemPath() {
663   return GetTestCertsDirectory().AppendASCII("root_ca_cert.pem");
664 }
665 
ShutdownOnIOThread()666 void EmbeddedTestServer::ShutdownOnIOThread() {
667   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
668   weak_factory_.InvalidateWeakPtrs();
669   listen_socket_.reset();
670   connections_.clear();
671 }
672 
HandleRequest(base::WeakPtr<HttpResponseDelegate> delegate,std::unique_ptr<HttpRequest> request)673 void EmbeddedTestServer::HandleRequest(
674     base::WeakPtr<HttpResponseDelegate> delegate,
675     std::unique_ptr<HttpRequest> request) {
676   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
677   request->base_url = base_url_;
678 
679   for (const auto& monitor : request_monitors_)
680     monitor.Run(*request);
681 
682   std::unique_ptr<HttpResponse> response;
683 
684   for (const auto& handler : request_handlers_) {
685     response = handler.Run(*request);
686     if (response)
687       break;
688   }
689 
690   if (!response) {
691     for (const auto& handler : default_request_handlers_) {
692       response = handler.Run(*request);
693       if (response)
694         break;
695     }
696   }
697 
698   if (!response) {
699     LOG(WARNING) << "Request not handled. Returning 404: "
700                  << request->relative_url;
701     auto not_found_response = std::make_unique<BasicHttpResponse>();
702     not_found_response->set_code(HTTP_NOT_FOUND);
703     response = std::move(not_found_response);
704   }
705 
706   HttpResponse* const response_ptr = response.get();
707   delegate->AddResponse(std::move(response));
708   response_ptr->SendResponse(delegate);
709 }
710 
GetURL(std::string_view relative_url) const711 GURL EmbeddedTestServer::GetURL(std::string_view relative_url) const {
712   DCHECK(Started()) << "You must start the server first.";
713   DCHECK(relative_url.starts_with("/")) << relative_url;
714   return base_url_.Resolve(relative_url);
715 }
716 
GetURL(std::string_view hostname,std::string_view relative_url) const717 GURL EmbeddedTestServer::GetURL(std::string_view hostname,
718                                 std::string_view relative_url) const {
719   GURL local_url = GetURL(relative_url);
720   GURL::Replacements replace_host;
721   replace_host.SetHostStr(hostname);
722   return local_url.ReplaceComponents(replace_host);
723 }
724 
GetOrigin(const std::optional<std::string> & hostname) const725 url::Origin EmbeddedTestServer::GetOrigin(
726     const std::optional<std::string>& hostname) const {
727   if (hostname)
728     return url::Origin::Create(GetURL(*hostname, "/"));
729   return url::Origin::Create(base_url_);
730 }
731 
GetAddressList(AddressList * address_list) const732 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
733   *address_list = AddressList(local_endpoint_);
734   return true;
735 }
736 
GetIPLiteralString() const737 std::string EmbeddedTestServer::GetIPLiteralString() const {
738   return local_endpoint_.address().ToString();
739 }
740 
SetSSLConfigInternal(ServerCertificate cert,const ServerCertificateConfig * cert_config,const SSLServerConfig & ssl_config)741 void EmbeddedTestServer::SetSSLConfigInternal(
742     ServerCertificate cert,
743     const ServerCertificateConfig* cert_config,
744     const SSLServerConfig& ssl_config) {
745   DCHECK(!Started());
746   cert_ = cert;
747   DCHECK(!cert_config || cert == CERT_AUTO);
748   cert_config_ = cert_config ? *cert_config : ServerCertificateConfig();
749   x509_cert_ = nullptr;
750   private_key_ = nullptr;
751   ssl_config_ = ssl_config;
752 }
753 
SetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)754 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
755                                       const SSLServerConfig& ssl_config) {
756   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, ssl_config);
757 }
758 
SetSSLConfig(ServerCertificate cert)759 void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
760   SetSSLConfigInternal(cert, /*cert_config=*/nullptr, SSLServerConfig());
761 }
762 
SetSSLConfig(const ServerCertificateConfig & cert_config,const SSLServerConfig & ssl_config)763 void EmbeddedTestServer::SetSSLConfig(
764     const ServerCertificateConfig& cert_config,
765     const SSLServerConfig& ssl_config) {
766   SetSSLConfigInternal(CERT_AUTO, &cert_config, ssl_config);
767 }
768 
SetSSLConfig(const ServerCertificateConfig & cert_config)769 void EmbeddedTestServer::SetSSLConfig(
770     const ServerCertificateConfig& cert_config) {
771   SetSSLConfigInternal(CERT_AUTO, &cert_config, SSLServerConfig());
772 }
773 
SetCertHostnames(std::vector<std::string> hostnames)774 void EmbeddedTestServer::SetCertHostnames(std::vector<std::string> hostnames) {
775   ServerCertificateConfig cert_config;
776   cert_config.dns_names = std::move(hostnames);
777   cert_config.ip_addresses = {net::IPAddress::IPv4Localhost()};
778   SetSSLConfig(cert_config);
779 }
780 
ResetSSLConfigOnIOThread(ServerCertificate cert,const SSLServerConfig & ssl_config)781 bool EmbeddedTestServer::ResetSSLConfigOnIOThread(
782     ServerCertificate cert,
783     const SSLServerConfig& ssl_config) {
784   cert_ = cert;
785   cert_config_ = ServerCertificateConfig();
786   ssl_config_ = ssl_config;
787   connections_.clear();
788   return InitializeSSLServerContext();
789 }
790 
ResetSSLConfig(ServerCertificate cert,const SSLServerConfig & ssl_config)791 bool EmbeddedTestServer::ResetSSLConfig(ServerCertificate cert,
792                                         const SSLServerConfig& ssl_config) {
793   return PostTaskToIOThreadAndWaitWithResult(
794       base::BindOnce(&EmbeddedTestServer::ResetSSLConfigOnIOThread,
795                      base::Unretained(this), cert, ssl_config));
796 }
797 
GetCertificateName() const798 std::string EmbeddedTestServer::GetCertificateName() const {
799   DCHECK(is_using_ssl_);
800   switch (cert_) {
801     case CERT_OK:
802     case CERT_MISMATCHED_NAME:
803       return "ok_cert.pem";
804     case CERT_COMMON_NAME_IS_DOMAIN:
805       return "localhost_cert.pem";
806     case CERT_EXPIRED:
807       return "expired_cert.pem";
808     case CERT_CHAIN_WRONG_ROOT:
809       // This chain uses its own dedicated test root certificate to avoid
810       // side-effects that may affect testing.
811       return "redundant-server-chain.pem";
812     case CERT_COMMON_NAME_ONLY:
813       return "common_name_only.pem";
814     case CERT_SHA1_LEAF:
815       return "sha1_leaf.pem";
816     case CERT_OK_BY_INTERMEDIATE:
817       return "ok_cert_by_intermediate.pem";
818     case CERT_BAD_VALIDITY:
819       return "bad_validity.pem";
820     case CERT_TEST_NAMES:
821       return "test_names.pem";
822     case CERT_KEY_USAGE_RSA_ENCIPHERMENT:
823       return "key_usage_rsa_keyencipherment.pem";
824     case CERT_KEY_USAGE_RSA_DIGITAL_SIGNATURE:
825       return "key_usage_rsa_digitalsignature.pem";
826     case CERT_AUTO:
827       return std::string();
828   }
829 
830   return "ok_cert.pem";
831 }
832 
GetCertificate()833 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() {
834   DCHECK(is_using_ssl_);
835   if (!x509_cert_) {
836     // Some tests want to get the certificate before the server has been
837     // initialized, so load it now if necessary. This is only possible if using
838     // a static certificate.
839     // TODO(mattm): change contract to require initializing first in all cases,
840     // update callers.
841     CHECK(UsingStaticCert());
842     // TODO(mattm): change contract to return nullptr on error instead of
843     // CHECKing, update callers.
844     CHECK(InitializeCertAndKeyFromFile());
845   }
846   return x509_cert_;
847 }
848 
GetGeneratedIntermediate()849 scoped_refptr<X509Certificate> EmbeddedTestServer::GetGeneratedIntermediate() {
850   DCHECK(is_using_ssl_);
851   DCHECK(!UsingStaticCert());
852   return intermediate_;
853 }
854 
ServeFilesFromDirectory(const base::FilePath & directory)855 void EmbeddedTestServer::ServeFilesFromDirectory(
856     const base::FilePath& directory) {
857   RegisterDefaultHandler(base::BindRepeating(&HandleFileRequest, directory));
858 }
859 
ServeFilesFromSourceDirectory(std::string_view relative)860 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
861     std::string_view relative) {
862   base::FilePath test_data_dir;
863   CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
864   ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
865 }
866 
ServeFilesFromSourceDirectory(const base::FilePath & relative)867 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
868     const base::FilePath& relative) {
869   ServeFilesFromDirectory(GetFullPathFromSourceDirectory(relative));
870 }
871 
AddDefaultHandlers(const base::FilePath & directory)872 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
873   ServeFilesFromSourceDirectory(directory);
874   AddDefaultHandlers();
875 }
876 
AddDefaultHandlers()877 void EmbeddedTestServer::AddDefaultHandlers() {
878   RegisterDefaultHandlers(this);
879 }
880 
// Returns |relative| appended to base::DIR_SRC_TEST_DATA_ROOT. CHECK-fails if
// that root cannot be resolved.
base::FilePath EmbeddedTestServer::GetFullPathFromSourceDirectory(
    const base::FilePath& relative) {
  base::FilePath source_root;
  const bool resolved =
      base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &source_root);
  CHECK(resolved);
  return source_root.Append(relative);
}
887 
RegisterRequestHandler(const HandleRequestCallback & callback)888 void EmbeddedTestServer::RegisterRequestHandler(
889     const HandleRequestCallback& callback) {
890   DCHECK(!io_thread_)
891       << "Handlers must be registered before starting the server.";
892   request_handlers_.push_back(callback);
893 }
894 
RegisterRequestMonitor(const MonitorRequestCallback & callback)895 void EmbeddedTestServer::RegisterRequestMonitor(
896     const MonitorRequestCallback& callback) {
897   DCHECK(!io_thread_)
898       << "Monitors must be registered before starting the server.";
899   request_monitors_.push_back(callback);
900 }
901 
RegisterDefaultHandler(const HandleRequestCallback & callback)902 void EmbeddedTestServer::RegisterDefaultHandler(
903     const HandleRequestCallback& callback) {
904   DCHECK(!io_thread_)
905       << "Handlers must be registered before starting the server.";
906   default_request_handlers_.push_back(callback);
907 }
908 
DoSSLUpgrade(std::unique_ptr<StreamSocket> connection)909 std::unique_ptr<SSLServerSocket> EmbeddedTestServer::DoSSLUpgrade(
910     std::unique_ptr<StreamSocket> connection) {
911   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
912 
913   return context_->CreateSSLServerSocket(std::move(connection));
914 }
915 
DoAcceptLoop()916 void EmbeddedTestServer::DoAcceptLoop() {
917   while (true) {
918     int rv = listen_socket_->Accept(
919         &accepted_socket_,
920         base::BindOnce(&EmbeddedTestServer::OnAcceptCompleted,
921                        base::Unretained(this)));
922     if (rv != OK)
923       return;
924 
925     HandleAcceptResult(std::move(accepted_socket_));
926   }
927 }
928 
FlushAllSocketsAndConnectionsOnUIThread()929 bool EmbeddedTestServer::FlushAllSocketsAndConnectionsOnUIThread() {
930   return PostTaskToIOThreadAndWait(
931       base::BindOnce(&EmbeddedTestServer::FlushAllSocketsAndConnections,
932                      base::Unretained(this)));
933 }
934 
FlushAllSocketsAndConnections()935 void EmbeddedTestServer::FlushAllSocketsAndConnections() {
936   connections_.clear();
937 }
938 
SetAlpsAcceptCH(std::string hostname,std::string accept_ch)939 void EmbeddedTestServer::SetAlpsAcceptCH(std::string hostname,
940                                          std::string accept_ch) {
941   alps_accept_ch_.insert_or_assign(std::move(hostname), std::move(accept_ch));
942 }
943 
OnAcceptCompleted(int rv)944 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
945   DCHECK_NE(ERR_IO_PENDING, rv);
946   HandleAcceptResult(std::move(accepted_socket_));
947   DoAcceptLoop();
948 }
949 
OnHandshakeDone(HttpConnection * connection,int rv)950 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
951   if (connection->Socket()->IsConnected()) {
952     connection->OnSocketReady();
953   } else {
954     RemoveConnection(connection);
955   }
956 }
957 
HandleAcceptResult(std::unique_ptr<StreamSocket> socket_ptr)958 void EmbeddedTestServer::HandleAcceptResult(
959     std::unique_ptr<StreamSocket> socket_ptr) {
960   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
961   if (connection_listener_)
962     socket_ptr = connection_listener_->AcceptedSocket(std::move(socket_ptr));
963 
964   if (!is_using_ssl_) {
965     AddConnection(std::move(socket_ptr))->OnSocketReady();
966     return;
967   }
968 
969   socket_ptr = DoSSLUpgrade(std::move(socket_ptr));
970 
971   StreamSocket* socket = socket_ptr.get();
972   HttpConnection* connection = AddConnection(std::move(socket_ptr));
973 
974   int rv = static_cast<SSLServerSocket*>(socket)->Handshake(
975       base::BindOnce(&EmbeddedTestServer::OnHandshakeDone,
976                      base::Unretained(this), connection));
977   if (rv != ERR_IO_PENDING)
978     OnHandshakeDone(connection, rv);
979 }
980 
AddConnection(std::unique_ptr<StreamSocket> socket_ptr)981 HttpConnection* EmbeddedTestServer::AddConnection(
982     std::unique_ptr<StreamSocket> socket_ptr) {
983   StreamSocket* socket = socket_ptr.get();
984   std::unique_ptr<HttpConnection> connection_ptr = HttpConnection::Create(
985       std::move(socket_ptr), connection_listener_, this, protocol_);
986   HttpConnection* connection = connection_ptr.get();
987   connections_[socket] = std::move(connection_ptr);
988 
989   return connection;
990 }
991 
RemoveConnection(HttpConnection * connection,EmbeddedTestServerConnectionListener * listener)992 void EmbeddedTestServer::RemoveConnection(
993     HttpConnection* connection,
994     EmbeddedTestServerConnectionListener* listener) {
995   DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
996   DCHECK(connection);
997   DCHECK_EQ(1u, connections_.count(connection->Socket()));
998 
999   StreamSocket* raw_socket = connection->Socket();
1000   std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
1001   connections_.erase(raw_socket);
1002 
1003   if (listener && socket && socket->IsConnected())
1004     listener->OnResponseCompletedSuccessfully(std::move(socket));
1005 }
1006 
PostTaskToIOThreadAndWait(base::OnceClosure closure)1007 bool EmbeddedTestServer::PostTaskToIOThreadAndWait(base::OnceClosure closure) {
1008   // Note that PostTaskAndReply below requires
1009   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1010   // for posting the reply task. However, in order to make EmbeddedTestServer
1011   // universally usable, it needs to cope with the situation where it's running
1012   // on a thread on which a task executor is not (yet) available or has been
1013   // destroyed already.
1014   //
1015   // To handle this situation, create temporary task executor to support the
1016   // PostTaskAndReply operation if the current thread has no task executor.
1017   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1018   // if anything breaks.
1019   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1020   if (!base::CurrentThread::Get())
1021     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1022 
1023   base::RunLoop run_loop;
1024   if (!io_thread_->task_runner()->PostTaskAndReply(
1025           FROM_HERE, std::move(closure), run_loop.QuitClosure())) {
1026     return false;
1027   }
1028   run_loop.Run();
1029 
1030   return true;
1031 }
1032 
PostTaskToIOThreadAndWaitWithResult(base::OnceCallback<bool ()> task)1033 bool EmbeddedTestServer::PostTaskToIOThreadAndWaitWithResult(
1034     base::OnceCallback<bool()> task) {
1035   // Note that PostTaskAndReply below requires
1036   // base::SingleThreadTaskRunner::GetCurrentDefault() to return a task runner
1037   // for posting the reply task. However, in order to make EmbeddedTestServer
1038   // universally usable, it needs to cope with the situation where it's running
1039   // on a thread on which a task executor is not (yet) available or has been
1040   // destroyed already.
1041   //
1042   // To handle this situation, create temporary task executor to support the
1043   // PostTaskAndReply operation if the current thread has no task executor.
1044   // TODO(mattm): Is this still necessary/desirable? Try removing this and see
1045   // if anything breaks.
1046   std::unique_ptr<base::SingleThreadTaskExecutor> temporary_loop;
1047   if (!base::CurrentThread::Get())
1048     temporary_loop = std::make_unique<base::SingleThreadTaskExecutor>();
1049 
1050   base::RunLoop run_loop;
1051   bool task_result = false;
1052   if (!io_thread_->task_runner()->PostTaskAndReplyWithResult(
1053           FROM_HERE, std::move(task),
1054           base::BindOnce(base::BindLambdaForTesting([&](bool result) {
1055             task_result = result;
1056             run_loop.Quit();
1057           })))) {
1058     return false;
1059   }
1060   run_loop.Run();
1061 
1062   return task_result;
1063 }
1064 
1065 }  // namespace net::test_server
1066