xref: /aosp_15_r20/external/cronet/net/test/test_doh_server.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2021 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/test_doh_server.h"
6 
7 #include <string.h>
8 
9 #include <memory>
10 #include <string_view>
11 
12 #include "base/base64url.h"
13 #include "base/check.h"
14 #include "base/functional/bind.h"
15 #include "base/logging.h"
16 #include "base/memory/scoped_refptr.h"
17 #include "base/ranges/algorithm.h"
18 #include "base/strings/string_number_conversions.h"
19 #include "base/synchronization/lock.h"
20 #include "base/time/time.h"
21 #include "net/base/io_buffer.h"
22 #include "net/base/url_util.h"
23 #include "net/dns/dns_names_util.h"
24 #include "net/dns/dns_query.h"
25 #include "net/dns/dns_response.h"
26 #include "net/dns/dns_test_util.h"
27 #include "net/dns/public/dns_protocol.h"
28 #include "net/http/http_status_code.h"
29 #include "net/test/embedded_test_server/embedded_test_server.h"
30 #include "net/test/embedded_test_server/http_request.h"
31 #include "net/test/embedded_test_server/http_response.h"
32 #include "url/gurl.h"
33 
34 namespace net {
35 
36 namespace {
37 
// URL path of the DoH endpoint. HandleRequest() only answers requests for this
// path, and GetTemplate()/GetPostOnlyTemplate() build their URLs from it.
constexpr char kPath[] = "/dns-query";
39 
MakeHttpErrorResponse(HttpStatusCode status,std::string_view error)40 std::unique_ptr<test_server::HttpResponse> MakeHttpErrorResponse(
41     HttpStatusCode status,
42     std::string_view error) {
43   auto response = std::make_unique<test_server::BasicHttpResponse>();
44   response->set_code(status);
45   response->set_content(std::string(error));
46   response->set_content_type("text/plain;charset=utf-8");
47   return response;
48 }
49 
MakeHttpResponseFromDns(const DnsResponse & dns_response)50 std::unique_ptr<test_server::HttpResponse> MakeHttpResponseFromDns(
51     const DnsResponse& dns_response) {
52   if (!dns_response.IsValid()) {
53     return MakeHttpErrorResponse(HTTP_INTERNAL_SERVER_ERROR,
54                                  "error making DNS response");
55   }
56 
57   auto response = std::make_unique<test_server::BasicHttpResponse>();
58   response->set_code(HTTP_OK);
59   response->set_content(std::string(dns_response.io_buffer()->data(),
60                                     dns_response.io_buffer_size()));
61   response->set_content_type("application/dns-message");
62   return response;
63 }
64 
65 }  // namespace
66 
TestDohServer()67 TestDohServer::TestDohServer() {
68   server_.RegisterRequestHandler(base::BindRepeating(
69       &TestDohServer::HandleRequest, base::Unretained(this)));
70 }
71 
72 TestDohServer::~TestDohServer() = default;
73 
SetHostname(std::string_view name)74 void TestDohServer::SetHostname(std::string_view name) {
75   DCHECK(!server_.Started());
76   hostname_ = std::string(name);
77 }
78 
SetFailRequests(bool fail_requests)79 void TestDohServer::SetFailRequests(bool fail_requests) {
80   base::AutoLock lock(lock_);
81   fail_requests_ = fail_requests;
82 }
83 
AddAddressRecord(std::string_view name,const IPAddress & address,base::TimeDelta ttl)84 void TestDohServer::AddAddressRecord(std::string_view name,
85                                      const IPAddress& address,
86                                      base::TimeDelta ttl) {
87   AddRecord(BuildTestAddressRecord(std::string(name), address, ttl));
88 }
89 
AddRecord(const DnsResourceRecord & record)90 void TestDohServer::AddRecord(const DnsResourceRecord& record) {
91   base::AutoLock lock(lock_);
92   records_.emplace(std::pair(record.name, record.type), record);
93 }
94 
Start()95 bool TestDohServer::Start() {
96   if (!InitializeAndListen()) {
97     return false;
98   }
99   StartAcceptingConnections();
100   return true;
101 }
102 
InitializeAndListen()103 bool TestDohServer::InitializeAndListen() {
104   if (hostname_) {
105     EmbeddedTestServer::ServerCertificateConfig cert_config;
106     cert_config.dns_names = {*hostname_};
107     server_.SetSSLConfig(cert_config);
108   } else {
109     // `CERT_OK` is valid for 127.0.0.1.
110     server_.SetSSLConfig(EmbeddedTestServer::CERT_OK);
111   }
112   return server_.InitializeAndListen();
113 }
114 
// Starts accepting connections on the embedded server. Forwards directly to
// `server_`; presumably InitializeAndListen() must have succeeded first
// (as Start() does) — confirm against EmbeddedTestServer's contract.
void TestDohServer::StartAcceptingConnections() {
  server_.StartAcceptingConnections();
}
118 
ShutdownAndWaitUntilComplete()119 bool TestDohServer::ShutdownAndWaitUntilComplete() {
120   return server_.ShutdownAndWaitUntilComplete();
121 }
122 
GetTemplate()123 std::string TestDohServer::GetTemplate() {
124   GURL url =
125       hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
126   return url.spec() + "{?dns}";
127 }
128 
GetPostOnlyTemplate()129 std::string TestDohServer::GetPostOnlyTemplate() {
130   GURL url =
131       hostname_ ? server_.GetURL(*hostname_, kPath) : server_.GetURL(kPath);
132   return url.spec();
133 }
134 
QueriesServed()135 int TestDohServer::QueriesServed() {
136   base::AutoLock lock(lock_);
137   return queries_served_;
138 }
139 
QueriesServedForSubdomains(std::string_view domain)140 int TestDohServer::QueriesServedForSubdomains(std::string_view domain) {
141   CHECK(net::dns_names_util::IsValidDnsName(domain));
142   auto is_subdomain = [&domain](std::string_view candidate) {
143     return net::IsSubdomainOf(candidate, domain);
144   };
145   base::AutoLock lock(lock_);
146   return base::ranges::count_if(query_qnames_, is_subdomain);
147 }
148 
HandleRequest(const test_server::HttpRequest & request)149 std::unique_ptr<test_server::HttpResponse> TestDohServer::HandleRequest(
150     const test_server::HttpRequest& request) {
151   GURL request_url = request.GetURL();
152   if (request_url.path_piece() != kPath) {
153     return nullptr;
154   }
155 
156   base::AutoLock lock(lock_);
157   queries_served_++;
158 
159   if (fail_requests_) {
160     return MakeHttpErrorResponse(HTTP_NOT_FOUND, "failed request");
161   }
162 
163   // See RFC 8484, Section 4.1.
164   std::string query;
165   if (request.method == test_server::METHOD_GET) {
166     std::string query_b64;
167     if (!GetValueForKeyInQuery(request_url, "dns", &query_b64) ||
168         !base::Base64UrlDecode(
169             query_b64, base::Base64UrlDecodePolicy::IGNORE_PADDING, &query)) {
170       return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
171                                    "could not decode query string");
172     }
173   } else if (request.method == test_server::METHOD_POST) {
174     auto content_type = request.headers.find("content-type");
175     if (content_type == request.headers.end() ||
176         content_type->second != "application/dns-message") {
177       return MakeHttpErrorResponse(HTTP_BAD_REQUEST,
178                                    "unsupported content type");
179     }
180     query = request.content;
181   } else {
182     return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid method");
183   }
184 
185   // Parse the DNS query.
186   auto query_buf = base::MakeRefCounted<IOBufferWithSize>(query.size());
187   memcpy(query_buf->data(), query.data(), query.size());
188   DnsQuery dns_query(std::move(query_buf));
189   if (!dns_query.Parse(query.size())) {
190     return MakeHttpErrorResponse(HTTP_BAD_REQUEST, "invalid DNS query");
191   }
192 
193   std::optional<std::string> name = dns_names_util::NetworkToDottedName(
194       dns_query.qname(), /*require_complete=*/true);
195   if (!name) {
196     DnsResponse response(dns_query.id(), /*is_authoritative=*/false,
197                          /*answers=*/{}, /*authority_records=*/{},
198                          /*additional_records=*/{}, dns_query,
199                          dns_protocol::kRcodeFORMERR);
200     return MakeHttpResponseFromDns(response);
201   }
202   query_qnames_.push_back(*name);
203 
204   auto range = records_.equal_range(std::pair(*name, dns_query.qtype()));
205   std::vector<DnsResourceRecord> answers;
206   for (auto i = range.first; i != range.second; ++i) {
207     answers.push_back(i->second);
208   }
209 
210   LOG(INFO) << "Serving " << answers.size() << " records for " << *name
211             << ", qtype " << dns_query.qtype();
212 
213   // Note `answers` may be empty. NOERROR with no answers is how to express
214   // NODATA, so there is no need handle it specially.
215   //
216   // For now, this server does not support configuring additional records. When
217   // testing more complex HTTPS record cases, this will need to be extended.
218   //
219   // TODO(crbug.com/1251204): Add SOA records to test the default TTL.
220   DnsResponse response(dns_query.id(), /*is_authoritative=*/true,
221                        /*answers=*/answers, /*authority_records=*/{},
222                        /*additional_records=*/{}, dns_query);
223   return MakeHttpResponseFromDns(response);
224 }
225 
226 }  // namespace net
227