xref: /aosp_15_r20/external/cronet/net/dns/address_sorter_win.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/dns/address_sorter.h"

#include <stdlib.h>
#include <winsock2.h>

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "base/check.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/free_deleter.h"
#include "base/memory/ref_counted.h"
#include "base/numerics/checked_math.h"
#include "base/numerics/safe_conversions.h"
#include "base/task/thread_pool.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/winsock_init.h"
21 
22 namespace net {
23 
24 namespace {
25 
26 class AddressSorterWin : public AddressSorter {
27  public:
AddressSorterWin()28   AddressSorterWin() {
29     EnsureWinsockInit();
30   }
31 
32   AddressSorterWin(const AddressSorterWin&) = delete;
33   AddressSorterWin& operator=(const AddressSorterWin&) = delete;
34 
~AddressSorterWin()35   ~AddressSorterWin() override {}
36 
37   // AddressSorter:
Sort(const std::vector<IPEndPoint> & endpoints,CallbackType callback) const38   void Sort(const std::vector<IPEndPoint>& endpoints,
39             CallbackType callback) const override {
40     DCHECK(!endpoints.empty());
41     Job::Start(endpoints, std::move(callback));
42   }
43 
44  private:
45   // Executes the SIO_ADDRESS_LIST_SORT ioctl asynchronously, and
46   // performs the necessary conversions to/from `std::vector<IPEndPoint>`.
47   class Job : public base::RefCountedThreadSafe<Job> {
48    public:
Start(const std::vector<IPEndPoint> & endpoints,CallbackType callback)49     static void Start(const std::vector<IPEndPoint>& endpoints,
50                       CallbackType callback) {
51       auto job = base::WrapRefCounted(new Job(endpoints, std::move(callback)));
52       base::ThreadPool::PostTaskAndReply(
53           FROM_HERE,
54           {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
55           base::BindOnce(&Job::Run, job),
56           base::BindOnce(&Job::OnComplete, job));
57     }
58 
59     Job(const Job&) = delete;
60     Job& operator=(const Job&) = delete;
61 
62    private:
63     friend class base::RefCountedThreadSafe<Job>;
64 
Job(const std::vector<IPEndPoint> & endpoints,CallbackType callback)65     Job(const std::vector<IPEndPoint>& endpoints, CallbackType callback)
66         : callback_(std::move(callback)),
67           buffer_size_((sizeof(SOCKET_ADDRESS_LIST) +
68                         base::CheckedNumeric<DWORD>(endpoints.size()) *
69                             (sizeof(SOCKET_ADDRESS) + sizeof(SOCKADDR_STORAGE)))
70                            .ValueOrDie<DWORD>()),
71           input_buffer_(
72               reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))),
73           output_buffer_(
74               reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))) {
75       input_buffer_->iAddressCount = base::checked_cast<INT>(endpoints.size());
76       SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
77           input_buffer_->Address + input_buffer_->iAddressCount);
78 
79       for (size_t i = 0; i < endpoints.size(); ++i) {
80         IPEndPoint ipe = endpoints[i];
81         // Addresses must be sockaddr_in6.
82         if (ipe.address().IsIPv4()) {
83           ipe = IPEndPoint(ConvertIPv4ToIPv4MappedIPv6(ipe.address()),
84                            ipe.port());
85         }
86 
87         struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
88         socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
89         bool result = ipe.ToSockAddr(addr, &addr_len);
90         DCHECK(result);
91         input_buffer_->Address[i].lpSockaddr = addr;
92         input_buffer_->Address[i].iSockaddrLength = addr_len;
93       }
94     }
95 
~Job()96     ~Job() {}
97 
98     // Executed asynchronously in ThreadPool.
Run()99     void Run() {
100       SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
101       if (sock == INVALID_SOCKET)
102         return;
103       DWORD result_size = 0;
104       int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
105                             buffer_size_, output_buffer_.get(), buffer_size_,
106                             &result_size, nullptr, nullptr);
107       if (result == SOCKET_ERROR) {
108         LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
109       } else {
110         success_ = true;
111       }
112       closesocket(sock);
113     }
114 
115     // Executed on the calling thread.
OnComplete()116     void OnComplete() {
117       std::vector<IPEndPoint> sorted;
118       if (success_) {
119         sorted.reserve(output_buffer_->iAddressCount);
120         for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
121           IPEndPoint ipe;
122           bool result =
123               ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
124                                output_buffer_->Address[i].iSockaddrLength);
125           DCHECK(result) << "Unable to roundtrip between IPEndPoint and "
126                          << "SOCKET_ADDRESS!";
127           // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
128           if (ipe.address().IsIPv4MappedIPv6()) {
129             ipe = IPEndPoint(ConvertIPv4MappedIPv6ToIPv4(ipe.address()),
130                              ipe.port());
131           }
132           sorted.push_back(ipe);
133         }
134       }
135       std::move(callback_).Run(success_, std::move(sorted));
136     }
137 
138     CallbackType callback_;
139     const DWORD buffer_size_;
140     std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> input_buffer_;
141     std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> output_buffer_;
142     bool success_ = false;
143   };
144 };
145 
146 }  // namespace
147 
148 // static
CreateAddressSorter()149 std::unique_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
150   return std::make_unique<AddressSorterWin>();
151 }
152 
153 }  // namespace net
154