xref: /aosp_15_r20/external/grpc-grpc/src/core/lib/event_engine/windows/grpc_polled_fd_windows.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2023 The gRPC Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <grpc/support/port_platform.h>
16 
17 #include "src/core/lib/iomgr/port.h"  // IWYU pragma: keep
18 
19 #if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
20 
21 #include <winsock2.h>
22 
23 #include <ares.h>
24 
25 #include "absl/functional/any_invocable.h"
26 #include "absl/status/status.h"
27 #include "absl/strings/str_format.h"
28 
29 #include <grpc/support/log_windows.h>
30 
31 #include "src/core/lib/address_utils/sockaddr_utils.h"
32 #include "src/core/lib/event_engine/ares_resolver.h"
33 #include "src/core/lib/event_engine/grpc_polled_fd.h"
34 #include "src/core/lib/event_engine/windows/grpc_polled_fd_windows.h"
35 #include "src/core/lib/event_engine/windows/win_socket.h"
36 #include "src/core/lib/gprpp/debug_location.h"
37 #include "src/core/lib/gprpp/sync.h"
38 #include "src/core/lib/iomgr/error.h"
39 #include "src/core/lib/slice/slice.h"
40 
41 // TODO(apolcyn): remove this hack after fixing upstream.
42 // Our grpc/c-ares code on Windows uses the ares_set_socket_functions API,
43 // which uses "struct iovec" type, which on Windows is defined inside of
44 // a c-ares header that is not public.
45 // See https://github.com/c-ares/c-ares/issues/206.
struct iovec {
  // Start of this buffer segment.
  void* iov_base;
  // Number of bytes in the segment starting at iov_base.
  size_t iov_len;
};
50 
51 namespace grpc_event_engine {
52 namespace experimental {
53 namespace {
54 
// Capacity, in bytes, of the per-fd buffer that WSARecvFrom fills with the
// source address of a received packet.
constexpr int kRecvFromSourceAddrSize = 200;
// Size, in bytes, of the slice allocated for each registered read.
constexpr int kReadBufferSize = 4192;
57 
FlattenIovec(const struct iovec * iov,int iov_count)58 grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) {
59   int total = 0;
60   for (int i = 0; i < iov_count; i++) {
61     total += iov[i].iov_len;
62   }
63   grpc_slice out = GRPC_SLICE_MALLOC(total);
64   size_t cur = 0;
65   for (int i = 0; i < iov_count; i++) {
66     for (size_t k = 0; k < iov[i].iov_len; k++) {
67       GRPC_SLICE_START_PTR(out)
68       [cur++] = (static_cast<char*>(iov[i].iov_base))[k];
69     }
70   }
71   return out;
72 }
73 
74 }  // namespace
75 
76 // c-ares reads and takes action on the error codes of the
77 // "virtual socket operations" in this file, via the WSAGetLastError
78 // APIs. If code in this file wants to set a specific WSA error that
79 // c-ares should read, it must do so by calling SetWSAError() on the
80 // WSAErrorContext instance passed to it. A WSAErrorContext must only be
81 // instantiated at the top of the virtual socket function callstack.
82 class WSAErrorContext {
83  public:
WSAErrorContext()84   explicit WSAErrorContext(){};
85 
~WSAErrorContext()86   ~WSAErrorContext() {
87     if (error_ != 0) {
88       WSASetLastError(error_);
89     }
90   }
91 
92   // Disallow copy and assignment operators
93   WSAErrorContext(const WSAErrorContext&) = delete;
94   WSAErrorContext& operator=(const WSAErrorContext&) = delete;
95 
SetWSAError(int error)96   void SetWSAError(int error) { error_ = error; }
97 
98  private:
99   int error_ = 0;
100 };
101 
102 // c-ares creates its own sockets and is meant to read them when readable and
103 // write them when writeable. To fit this socket usage model into the grpc
104 // windows poller (which gives notifications when attempted reads and writes
105 // are actually fulfilled rather than possible), this GrpcPolledFdWindows
106 // class takes advantage of the ares_set_socket_functions API and acts as a
107 // virtual socket. It holds its own read and write buffers which are written
108 // to and read from c-ares and are used with the grpc windows poller, and it,
109 // e.g., manufactures virtual socket error codes when it e.g. needs to tell
110 // the c-ares library to wait for an async read.
111 class GrpcPolledFdWindows : public GrpcPolledFd {
112  public:
  // Takes ownership of winsocket and stores the (non-owned) mutex and
  // EventEngine pointers. The read and write buffers start out as empty
  // slices, and the three IOCP closures are bound to this object's
  // OnIocpReadable / OnIocpWriteable / OnTcpConnect handlers.
  GrpcPolledFdWindows(std::unique_ptr<WinSocket> winsocket,
                      grpc_core::Mutex* mu, int address_family, int socket_type,
                      EventEngine* event_engine)
      // name_ is declared before winsocket_, so the raw socket is read here
      // before winsocket is moved from below.
      : name_(absl::StrFormat("c-ares socket: %" PRIdPTR,
                              winsocket->raw_socket())),
        address_family_(address_family),
        socket_type_(socket_type),
        mu_(mu),
        winsocket_(std::move(winsocket)),
        read_buf_(grpc_empty_slice()),
        write_buf_(grpc_empty_slice()),
        outer_read_closure_([this]() { OnIocpReadable(); }),
        outer_write_closure_([this]() { OnIocpWriteable(); }),
        on_tcp_connect_locked_([this]() { OnTcpConnect(); }),
        event_engine_(event_engine) {}
128 
~GrpcPolledFdWindows()129   ~GrpcPolledFdWindows() override {
130     GRPC_ARES_RESOLVER_TRACE_LOG(
131         "fd:|%s| ~GrpcPolledFdWindows shutdown_called_: %d ", GetName(),
132         shutdown_called_);
133     grpc_core::CSliceUnref(read_buf_);
134     grpc_core::CSliceUnref(write_buf_);
135     GPR_ASSERT(read_closure_ == nullptr);
136     GPR_ASSERT(write_closure_ == nullptr);
137     if (!shutdown_called_) {
138       winsocket_->Shutdown(DEBUG_LOCATION, "~GrpcPolledFdWindows");
139     }
140   }
141 
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)142   void RegisterForOnReadableLocked(
143       absl::AnyInvocable<void(absl::Status)> read_closure) override {
144     GPR_ASSERT(read_closure_ == nullptr);
145     read_closure_ = std::move(read_closure);
146     grpc_core::CSliceUnref(read_buf_);
147     GPR_ASSERT(!read_buf_has_data_);
148     read_buf_ = GRPC_SLICE_MALLOC(kReadBufferSize);
149     if (connect_done_) {
150       ContinueRegisterForOnReadableLocked();
151     } else {
152       GPR_ASSERT(pending_continue_register_for_on_readable_locked_ == false);
153       pending_continue_register_for_on_readable_locked_ = true;
154     }
155   }
156 
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)157   void RegisterForOnWriteableLocked(
158       absl::AnyInvocable<void(absl::Status)> write_closure) override {
159     if (socket_type_ == SOCK_DGRAM) {
160       GRPC_ARES_RESOLVER_TRACE_LOG(
161           "fd:|%s| RegisterForOnWriteableLocked called", GetName());
162     } else {
163       GPR_ASSERT(socket_type_ == SOCK_STREAM);
164       GRPC_ARES_RESOLVER_TRACE_LOG(
165           "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d "
166           "connect_done_: %d",
167           GetName(), tcp_write_state_, connect_done_);
168     }
169     GPR_ASSERT(write_closure_ == nullptr);
170     write_closure_ = std::move(write_closure);
171     if (!connect_done_) {
172       GPR_ASSERT(!pending_continue_register_for_on_writeable_locked_);
173       pending_continue_register_for_on_writeable_locked_ = true;
174     } else {
175       ContinueRegisterForOnWriteableLocked();
176     }
177   }
178 
IsFdStillReadableLocked()179   bool IsFdStillReadableLocked() override { return read_buf_has_data_; }
180 
ShutdownLocked(absl::Status error)181   bool ShutdownLocked(absl::Status error) override {
182     GPR_ASSERT(!shutdown_called_);
183     if (!absl::IsCancelled(error)) {
184       return false;
185     }
186     GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| ShutdownLocked", GetName());
187     shutdown_called_ = true;
188     // The socket is disconnected and closed here since this is an external
189     // cancel request, e.g. a timeout. c-ares shouldn't do anything on the
190     // socket after this point except calling close which should then destroy
191     // the GrpcPolledFdWindows object.
192     winsocket_->Shutdown(DEBUG_LOCATION, "GrpcPolledFdWindows::ShutdownLocked");
193     return true;
194   }
195 
GetWrappedAresSocketLocked()196   ares_socket_t GetWrappedAresSocketLocked() override {
197     return winsocket_->raw_socket();
198   }
199 
GetName() const200   const char* GetName() const override { return name_.c_str(); }
201 
RecvFrom(WSAErrorContext * wsa_error_ctx,void * data,ares_socket_t data_len,int,struct sockaddr * from,ares_socklen_t * from_len)202   ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
203                         ares_socket_t data_len, int /* flags */,
204                         struct sockaddr* from, ares_socklen_t* from_len) {
205     GRPC_ARES_RESOLVER_TRACE_LOG(
206         "fd:|%s| RecvFrom called read_buf_has_data:%d Current read buf "
207         "length:|%d|",
208         GetName(), read_buf_has_data_, GRPC_SLICE_LENGTH(read_buf_));
209     if (!read_buf_has_data_) {
210       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
211       return -1;
212     }
213     ares_ssize_t bytes_read = 0;
214     for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) {
215       (static_cast<char*>(data))[i] = GRPC_SLICE_START_PTR(read_buf_)[i];
216       bytes_read++;
217     }
218     read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read,
219                                       GRPC_SLICE_LENGTH(read_buf_));
220     if (GRPC_SLICE_LENGTH(read_buf_) == 0) {
221       read_buf_has_data_ = false;
222     }
223     // c-ares overloads this recv_from virtual socket function to receive
224     // data on both UDP and TCP sockets, and from is nullptr for TCP.
225     if (from != nullptr) {
226       GPR_ASSERT(*from_len <= recv_from_source_addr_len_);
227       memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_);
228       *from_len = recv_from_source_addr_len_;
229     }
230     return bytes_read;
231   }
232 
SendV(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)233   ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
234                      int iov_count) {
235     GRPC_ARES_RESOLVER_TRACE_LOG(
236         "fd:|%s| SendV called connect_done_:%d wsa_connect_error_:%d",
237         GetName(), connect_done_, wsa_connect_error_);
238     if (!connect_done_) {
239       wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
240       return -1;
241     }
242     if (wsa_connect_error_ != 0) {
243       wsa_error_ctx->SetWSAError(wsa_connect_error_);
244       return -1;
245     }
246     switch (socket_type_) {
247       case SOCK_DGRAM:
248         return SendVUDP(wsa_error_ctx, iov, iov_count);
249       case SOCK_STREAM:
250         return SendVTCP(wsa_error_ctx, iov, iov_count);
251       default:
252         abort();
253     }
254   }
255 
Connect(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)256   int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
257               ares_socklen_t target_len) {
258     switch (socket_type_) {
259       case SOCK_DGRAM:
260         return ConnectUDP(wsa_error_ctx, target, target_len);
261       case SOCK_STREAM:
262         return ConnectTCP(wsa_error_ctx, target, target_len);
263       default:
264         grpc_core::Crash(
265             absl::StrFormat("Unknown socket_type_: %d", socket_type_));
266     }
267   }
268 
269  private:
  // State machine for buffered TCP writes.
  enum WriteState {
    // No write is buffered; SendVTCP may accept new data.
    WRITE_IDLE,
    // SendVTCP has buffered data into write_buf_, but the overlapped send has
    // not been issued yet.
    WRITE_REQUESTED,
    // ContinueRegisterForOnWriteableLocked has issued the overlapped send and
    // its completion has not been processed yet.
    WRITE_PENDING,
    // OnIocpWriteable saw the send complete without error; the next SendVTCP
    // call checks the retried data against write_buf_ and reports it as sent.
    WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY,
  };
276 
ScheduleAndNullReadClosure(absl::Status error)277   void ScheduleAndNullReadClosure(absl::Status error) {
278     event_engine_->Run([read_closure = std::move(read_closure_),
279                         error]() mutable { read_closure(error); });
280     read_closure_ = nullptr;
281   }
282 
ScheduleAndNullWriteClosure(absl::Status error)283   void ScheduleAndNullWriteClosure(absl::Status error) {
284     event_engine_->Run([write_closure = std::move(write_closure_),
285                         error]() mutable { write_closure(error); });
286     write_closure_ = nullptr;
287   }
288 
ContinueRegisterForOnReadableLocked()289   void ContinueRegisterForOnReadableLocked() {
290     GRPC_ARES_RESOLVER_TRACE_LOG(
291         "fd:|%s| ContinueRegisterForOnReadableLocked "
292         "wsa_connect_error_:%d",
293         GetName(), wsa_connect_error_);
294     GPR_ASSERT(connect_done_);
295     if (wsa_connect_error_ != 0) {
296       ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
297       return;
298     }
299     WSABUF buffer;
300     buffer.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(read_buf_));
301     buffer.len = GRPC_SLICE_LENGTH(read_buf_);
302     recv_from_source_addr_len_ = sizeof(recv_from_source_addr_);
303     DWORD flags = 0;
304     winsocket_->NotifyOnRead(&outer_read_closure_);
305     if (WSARecvFrom(winsocket_->raw_socket(), &buffer, 1, nullptr, &flags,
306                     reinterpret_cast<sockaddr*>(recv_from_source_addr_),
307                     &recv_from_source_addr_len_,
308                     winsocket_->read_info()->overlapped(), nullptr) != 0) {
309       int wsa_last_error = WSAGetLastError();
310       char* msg = gpr_format_message(wsa_last_error);
311       GRPC_ARES_RESOLVER_TRACE_LOG(
312           "fd:|%s| ContinueRegisterForOnReadableLocked WSARecvFrom error "
313           "code:|%d| "
314           "msg:|%s|",
315           GetName(), wsa_last_error, msg);
316       gpr_free(msg);
317       if (wsa_last_error != WSA_IO_PENDING) {
318         winsocket_->UnregisterReadCallback();
319         ScheduleAndNullReadClosure(
320             GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom"));
321         return;
322       }
323     }
324   }
325 
ContinueRegisterForOnWriteableLocked()326   void ContinueRegisterForOnWriteableLocked() {
327     GRPC_ARES_RESOLVER_TRACE_LOG(
328         "fd:|%s| ContinueRegisterForOnWriteableLocked "
329         "wsa_connect_error_:%d",
330         GetName(), wsa_connect_error_);
331     GPR_ASSERT(connect_done_);
332     if (wsa_connect_error_ != 0) {
333       ScheduleAndNullWriteClosure(
334           GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
335       return;
336     }
337     if (socket_type_ == SOCK_DGRAM) {
338       ScheduleAndNullWriteClosure(absl::OkStatus());
339       return;
340     }
341     GPR_ASSERT(socket_type_ == SOCK_STREAM);
342     int wsa_error_code = 0;
343     switch (tcp_write_state_) {
344       case WRITE_IDLE:
345         ScheduleAndNullWriteClosure(absl::OkStatus());
346         break;
347       case WRITE_REQUESTED:
348         tcp_write_state_ = WRITE_PENDING;
349         winsocket_->NotifyOnWrite(&outer_write_closure_);
350         if (SendWriteBuf(nullptr, winsocket_->write_info()->overlapped(),
351                          &wsa_error_code) != 0) {
352           winsocket_->UnregisterWriteCallback();
353           ScheduleAndNullWriteClosure(
354               GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)"));
355           return;
356         }
357         break;
358       case WRITE_PENDING:
359       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
360         grpc_core::Crash(
361             absl::StrFormat("Invalid tcp_write_state_: %d", tcp_write_state_));
362     }
363   }
364 
SendWriteBuf(LPDWORD bytes_sent_ptr,LPWSAOVERLAPPED overlapped,int * wsa_error_code)365   int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped,
366                    int* wsa_error_code) {
367     WSABUF buf;
368     buf.len = GRPC_SLICE_LENGTH(write_buf_);
369     buf.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(write_buf_));
370     DWORD flags = 0;
371     int out = WSASend(winsocket_->raw_socket(), &buf, 1, bytes_sent_ptr, flags,
372                       overlapped, nullptr);
373     *wsa_error_code = WSAGetLastError();
374     GRPC_ARES_RESOLVER_TRACE_LOG(
375         "fd:|%s| SendWriteBuf WSASend buf.len:%d *bytes_sent_ptr:%d "
376         "overlapped:%p "
377         "return:%d *wsa_error_code:%d",
378         GetName(), buf.len, bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0,
379         overlapped, out, *wsa_error_code);
380     return out;
381   }
382 
SendVUDP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)383   ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
384                         int iov_count) {
385     // c-ares doesn't handle retryable errors on writes of UDP sockets.
386     // Therefore, the sendv handler for UDP sockets must only attempt
387     // to write everything inline.
388     GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| SendVUDP called", GetName());
389     GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
390     grpc_core::CSliceUnref(write_buf_);
391     write_buf_ = FlattenIovec(iov, iov_count);
392     DWORD bytes_sent = 0;
393     int wsa_error_code = 0;
394     if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) {
395       grpc_core::CSliceUnref(write_buf_);
396       write_buf_ = grpc_empty_slice();
397       wsa_error_ctx->SetWSAError(wsa_error_code);
398       char* msg = gpr_format_message(wsa_error_code);
399       GRPC_ARES_RESOLVER_TRACE_LOG(
400           "fd:|%s| SendVUDP SendWriteBuf error code:%d msg:|%s|", GetName(),
401           wsa_error_code, msg);
402       gpr_free(msg);
403       return -1;
404     }
405     write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent,
406                                        GRPC_SLICE_LENGTH(write_buf_));
407     return bytes_sent;
408   }
409 
SendVTCP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)410   ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
411                         int iov_count) {
412     // The "sendv" handler on TCP sockets buffers up write
413     // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer
414     // out in the background, and making further send progress in general, will
415     // happen as long as c-ares continues to show interest in writeability on
416     // this fd.
417     GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| SendVTCP called tcp_write_state_:%d",
418                                  GetName(), tcp_write_state_);
419     switch (tcp_write_state_) {
420       case WRITE_IDLE:
421         tcp_write_state_ = WRITE_REQUESTED;
422         grpc_core::CSliceUnref(write_buf_);
423         write_buf_ = FlattenIovec(iov, iov_count);
424         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
425         return -1;
426       case WRITE_REQUESTED:
427       case WRITE_PENDING:
428         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
429         return -1;
430       case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
431         // c-ares is retrying a send on data that we previously returned
432         // WSAEWOULDBLOCK for, but then subsequently wrote out in the
433         // background. Right now, we assume that c-ares is retrying the same
434         // send again. If c-ares still needs to send even more data, we'll get
435         // to it eventually.
436         grpc_slice currently_attempted = FlattenIovec(iov, iov_count);
437         GPR_ASSERT(GRPC_SLICE_LENGTH(currently_attempted) >=
438                    GRPC_SLICE_LENGTH(write_buf_));
439         ares_ssize_t total_sent = 0;
440         for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) {
441           GPR_ASSERT(GRPC_SLICE_START_PTR(currently_attempted)[i] ==
442                      GRPC_SLICE_START_PTR(write_buf_)[i]);
443           total_sent++;
444         }
445         grpc_core::CSliceUnref(currently_attempted);
446         tcp_write_state_ = WRITE_IDLE;
447         return total_sent;
448     }
449     grpc_core::Crash(
450         absl::StrFormat("Unknown tcp_write_state_: %d", tcp_write_state_));
451   }
452 
OnTcpConnect()453   void OnTcpConnect() {
454     grpc_core::MutexLock lock(mu_);
455     GRPC_ARES_RESOLVER_TRACE_LOG(
456         "fd:%s InnerOnTcpConnectLocked "
457         "pending_register_for_readable:%d"
458         " pending_register_for_writeable:%d",
459         GetName(), pending_continue_register_for_on_readable_locked_,
460         pending_continue_register_for_on_writeable_locked_);
461     GPR_ASSERT(!connect_done_);
462     connect_done_ = true;
463     GPR_ASSERT(wsa_connect_error_ == 0);
464     if (shutdown_called_) {
465       wsa_connect_error_ = WSA_OPERATION_ABORTED;
466     } else {
467       DWORD transferred_bytes = 0;
468       DWORD flags;
469       BOOL wsa_success = WSAGetOverlappedResult(
470           winsocket_->raw_socket(), winsocket_->write_info()->overlapped(),
471           &transferred_bytes, FALSE, &flags);
472       GPR_ASSERT(transferred_bytes == 0);
473       if (!wsa_success) {
474         wsa_connect_error_ = WSAGetLastError();
475         char* msg = gpr_format_message(wsa_connect_error_);
476         GRPC_ARES_RESOLVER_TRACE_LOG(
477             "fd:%s InnerOnTcpConnectLocked WSA overlapped result code:%d "
478             "msg:|%s|",
479             GetName(), wsa_connect_error_, msg);
480         gpr_free(msg);
481       }
482     }
483     if (pending_continue_register_for_on_readable_locked_) {
484       ContinueRegisterForOnReadableLocked();
485     }
486     if (pending_continue_register_for_on_writeable_locked_) {
487       ContinueRegisterForOnWriteableLocked();
488     }
489   }
490 
ConnectUDP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)491   int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
492                  ares_socklen_t target_len) {
493     GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectUDP", GetName());
494     GPR_ASSERT(!connect_done_);
495     GPR_ASSERT(wsa_connect_error_ == 0);
496     SOCKET s = winsocket_->raw_socket();
497     int out =
498         WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr);
499     wsa_connect_error_ = WSAGetLastError();
500     wsa_error_ctx->SetWSAError(wsa_connect_error_);
501     connect_done_ = true;
502     char* msg = gpr_format_message(wsa_connect_error_);
503     GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s WSAConnect error code:|%d| msg:|%s|",
504                                  GetName(), wsa_connect_error_, msg);
505     gpr_free(msg);
506     // c-ares expects a posix-style connect API
507     return out == 0 ? 0 : -1;
508   }
509 
ConnectTCP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)510   int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
511                  ares_socklen_t target_len) {
512     GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectTCP", GetName());
513     LPFN_CONNECTEX ConnectEx;
514     GUID guid = WSAID_CONNECTEX;
515     DWORD ioctl_num_bytes;
516     SOCKET s = winsocket_->raw_socket();
517     if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
518                  &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr,
519                  nullptr) != 0) {
520       int wsa_last_error = WSAGetLastError();
521       wsa_error_ctx->SetWSAError(wsa_last_error);
522       char* msg = gpr_format_message(wsa_last_error);
523       GRPC_ARES_RESOLVER_TRACE_LOG(
524           "fd:%s WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:%d "
525           "msg:|%s|",
526           GetName(), wsa_last_error, msg);
527       gpr_free(msg);
528       connect_done_ = true;
529       wsa_connect_error_ = wsa_last_error;
530       return -1;
531     }
532     grpc_resolved_address wildcard4_addr;
533     grpc_resolved_address wildcard6_addr;
534     grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr);
535     grpc_resolved_address* local_address = nullptr;
536     if (address_family_ == AF_INET) {
537       local_address = &wildcard4_addr;
538     } else {
539       local_address = &wildcard6_addr;
540     }
541     if (bind(s, reinterpret_cast<struct sockaddr*>(local_address->addr),
542              static_cast<int>(local_address->len)) != 0) {
543       int wsa_last_error = WSAGetLastError();
544       wsa_error_ctx->SetWSAError(wsa_last_error);
545       char* msg = gpr_format_message(wsa_last_error);
546       GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s bind error code:%d msg:|%s|",
547                                    GetName(), wsa_last_error, msg);
548       gpr_free(msg);
549       connect_done_ = true;
550       wsa_connect_error_ = wsa_last_error;
551       return -1;
552     }
553     int out = 0;
554     // Register an async OnTcpConnect callback here since it is required by the
555     // WinSocket API.
556     winsocket_->NotifyOnWrite(&on_tcp_connect_locked_);
557     if (ConnectEx(s, target, target_len, nullptr, 0, nullptr,
558                   winsocket_->write_info()->overlapped()) == 0) {
559       out = -1;
560       int wsa_last_error = WSAGetLastError();
561       wsa_error_ctx->SetWSAError(wsa_last_error);
562       char* msg = gpr_format_message(wsa_last_error);
563       GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectEx error code:%d msg:|%s|",
564                                    GetName(), wsa_last_error, msg);
565       gpr_free(msg);
566       if (wsa_last_error == WSA_IO_PENDING) {
567         // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on
568         // connect, but an async connect on IOCP socket will give
569         // WSA_IO_PENDING, so we need to convert.
570         wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
571       } else {
572         winsocket_->UnregisterWriteCallback();
573         // By returning a non-retryable error to c-ares at this point,
574         // we're aborting the possibility of any future operations on this fd.
575         connect_done_ = true;
576         wsa_connect_error_ = wsa_last_error;
577         return -1;
578       }
579     }
580     return out;
581   }
582 
  // TODO(apolcyn): improve this error handling to be less conservative.
584   // An e.g. ECONNRESET error here should result in errors when
585   // c-ares reads from this socket later, but it shouldn't necessarily cancel
586   // the entire resolution attempt. Doing so will allow the "inject broken
587   // nameserver list" test to pass on Windows.
OnIocpReadable()588   void OnIocpReadable() {
589     grpc_core::MutexLock lock(mu_);
590     absl::Status error;
591     if (winsocket_->read_info()->result().wsa_error != 0) {
592       // WSAEMSGSIZE would be due to receiving more data
593       // than our read buffer's fixed capacity. Assume that
594       // the connection is TCP and read the leftovers
595       // in subsequent c-ares reads.
596       if (winsocket_->read_info()->result().wsa_error != WSAEMSGSIZE) {
597         error = GRPC_WSA_ERROR(winsocket_->read_info()->result().wsa_error,
598                                "OnIocpReadableInner");
599         GRPC_ARES_RESOLVER_TRACE_LOG(
600             "fd:|%s| OnIocpReadableInner winsocket_->read_info.wsa_error "
601             "code:|%d| msg:|%s|",
602             GetName(), winsocket_->read_info()->result().wsa_error,
603             grpc_core::StatusToString(error).c_str());
604       }
605     }
606     if (error.ok()) {
607       read_buf_ = grpc_slice_sub_no_ref(
608           read_buf_, 0, winsocket_->read_info()->result().bytes_transferred);
609       read_buf_has_data_ = true;
610     } else {
611       grpc_core::CSliceUnref(read_buf_);
612       read_buf_ = grpc_empty_slice();
613     }
614     GRPC_ARES_RESOLVER_TRACE_LOG(
615         "fd:|%s| OnIocpReadable finishing. read buf length now:|%d|", GetName(),
616         GRPC_SLICE_LENGTH(read_buf_));
617     ScheduleAndNullReadClosure(error);
618   }
619 
OnIocpWriteable()620   void OnIocpWriteable() {
621     grpc_core::MutexLock lock(mu_);
622     GRPC_ARES_RESOLVER_TRACE_LOG("OnIocpWriteableInner. fd:|%s|", GetName());
623     GPR_ASSERT(socket_type_ == SOCK_STREAM);
624     absl::Status error;
625     if (winsocket_->write_info()->result().wsa_error != 0) {
626       error = GRPC_WSA_ERROR(winsocket_->write_info()->result().wsa_error,
627                              "OnIocpWriteableInner");
628       GRPC_ARES_RESOLVER_TRACE_LOG(
629           "fd:|%s| OnIocpWriteableInner. winsocket_->write_info.wsa_error "
630           "code:|%d| msg:|%s|",
631           GetName(), winsocket_->write_info()->result().wsa_error,
632           grpc_core::StatusToString(error).c_str());
633     }
634     GPR_ASSERT(tcp_write_state_ == WRITE_PENDING);
635     if (error.ok()) {
636       tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY;
637       write_buf_ = grpc_slice_sub_no_ref(
638           write_buf_, 0, winsocket_->write_info()->result().bytes_transferred);
639       GRPC_ARES_RESOLVER_TRACE_LOG(
640           "fd:|%s| OnIocpWriteableInner. bytes transferred:%d", GetName(),
641           winsocket_->write_info()->result().bytes_transferred);
642     } else {
643       grpc_core::CSliceUnref(write_buf_);
644       write_buf_ = grpc_empty_slice();
645     }
646     ScheduleAndNullWriteClosure(error);
647   }
648 
  // Label built in the constructor from the raw socket; returned by GetName().
  const std::string name_;
  // Address family given at construction; ConnectTCP uses it to pick the
  // wildcard address to bind to.
  const int address_family_;
  // SOCK_DGRAM or SOCK_STREAM; selects the UDP or TCP code paths.
  const int socket_type_;
  // Mutex supplied at construction; locked by the IOCP completion callbacks.
  grpc_core::Mutex* mu_;
  // The wrapped socket that all reads, writes and connects go through.
  std::unique_ptr<WinSocket> winsocket_;
  // Source address (and its length) filled in by the overlapped WSARecvFrom.
  char recv_from_source_addr_[kRecvFromSourceAddrSize];
  ares_socklen_t recv_from_source_addr_len_;
  // Buffer the overlapped read receives into and RecvFrom drains.
  grpc_slice read_buf_;
  // True while read_buf_ holds received bytes not yet consumed by RecvFrom.
  bool read_buf_has_data_ = false;
  // Flattened copy of the data most recently handed to SendVUDP / SendVTCP.
  grpc_slice write_buf_;
  // Closures registered via RegisterForOn{Readable,Writeable}Locked; null when
  // nothing is registered.
  absl::AnyInvocable<void(absl::Status)> read_closure_;
  absl::AnyInvocable<void(absl::Status)> write_closure_;
  // Closures handed to the WinSocket; they invoke OnIocpReadable and
  // OnIocpWriteable respectively.
  AnyInvocableClosure outer_read_closure_;
  AnyInvocableClosure outer_write_closure_;
  // Set once ShutdownLocked has shut the socket down.
  bool shutdown_called_ = false;
  // State related to TCP sockets
  // Closure handed to the WinSocket for the connect; invokes OnTcpConnect.
  AnyInvocableClosure on_tcp_connect_locked_;
  // True once the connect attempt has finished, successfully or not.
  bool connect_done_ = false;
  // WSA error recorded for the connect attempt; 0 means no error.
  int wsa_connect_error_ = 0;
  WriteState tcp_write_state_ = WRITE_IDLE;
  // We don't run register_for_{readable,writeable} logic until
  // a socket is connected. In the interim, we queue readable/writeable
  // registrations with the following state.
  bool pending_continue_register_for_on_readable_locked_ = false;
  bool pending_continue_register_for_on_writeable_locked_ = false;
  // This pointer is initialized from the stored pointer inside the shared
  // pointer owned by the AresResolver and should be valid at the time of use.
  EventEngine* event_engine_;
677 };
678 
679 // These virtual socket functions are called from within the c-ares
680 // library. These methods generally dispatch those socket calls to the
681 // appropriate methods. The virtual "socket" and "close" methods are
682 // special and instead create/add and remove/destroy GrpcPolledFdWindows
683 // objects.
684 class CustomSockFuncs {
685  public:
Socket(int af,int type,int protocol,void * user_data)686   static ares_socket_t Socket(int af, int type, int protocol, void* user_data) {
687     if (type != SOCK_DGRAM && type != SOCK_STREAM) {
688       GRPC_ARES_RESOLVER_TRACE_LOG("Socket called with invalid socket type:%d",
689                                    type);
690       return INVALID_SOCKET;
691     }
692     GrpcPolledFdFactoryWindows* self =
693         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
694     SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
695                          IOCP::GetDefaultSocketFlags());
696     if (s == INVALID_SOCKET) {
697       GRPC_ARES_RESOLVER_TRACE_LOG(
698           "WSASocket failed with params af:%d type:%d protocol:%d", af, type,
699           protocol);
700       return INVALID_SOCKET;
701     }
702     if (type == SOCK_STREAM) {
703       absl::Status error = PrepareSocket(s);
704       if (!error.ok()) {
705         GRPC_ARES_RESOLVER_TRACE_LOG("WSAIoctl failed with error: %s",
706                                      grpc_core::StatusToString(error).c_str());
707         return INVALID_SOCKET;
708       }
709     }
710     auto polled_fd = std::make_unique<GrpcPolledFdWindows>(
711         self->iocp_->Watch(s), self->mu_, af, type, self->event_engine_);
712     GRPC_ARES_RESOLVER_TRACE_LOG(
713         "fd:|%s| created with params af:%d type:%d protocol:%d",
714         polled_fd->GetName(), af, type, protocol);
715     GPR_ASSERT(self->sockets_.insert({s, std::move(polled_fd)}).second);
716     return s;
717   }
718 
Connect(ares_socket_t as,const struct sockaddr * target,ares_socklen_t target_len,void * user_data)719   static int Connect(ares_socket_t as, const struct sockaddr* target,
720                      ares_socklen_t target_len, void* user_data) {
721     WSAErrorContext wsa_error_ctx;
722     GrpcPolledFdFactoryWindows* self =
723         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
724     auto it = self->sockets_.find(as);
725     GPR_ASSERT(it != self->sockets_.end());
726     return it->second->Connect(&wsa_error_ctx, target, target_len);
727   }
728 
SendV(ares_socket_t as,const struct iovec * iov,int iovec_count,void * user_data)729   static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
730                             int iovec_count, void* user_data) {
731     WSAErrorContext wsa_error_ctx;
732     GrpcPolledFdFactoryWindows* self =
733         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
734     auto it = self->sockets_.find(as);
735     GPR_ASSERT(it != self->sockets_.end());
736     return it->second->SendV(&wsa_error_ctx, iov, iovec_count);
737   }
738 
RecvFrom(ares_socket_t as,void * data,size_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len,void * user_data)739   static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
740                                int flags, struct sockaddr* from,
741                                ares_socklen_t* from_len, void* user_data) {
742     WSAErrorContext wsa_error_ctx;
743     GrpcPolledFdFactoryWindows* self =
744         static_cast<GrpcPolledFdFactoryWindows*>(user_data);
745     auto it = self->sockets_.find(as);
746     GPR_ASSERT(it != self->sockets_.end());
747     return it->second->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
748                                 from_len);
749   }
750 
CloseSocket(SOCKET s,void *)751   static int CloseSocket(SOCKET s, void*) {
752     GRPC_ARES_RESOLVER_TRACE_LOG("c-ares socket: %d CloseSocket", s);
753     return 0;
754   }
755 };
756 
757 // Adapter to hold the ownership of GrpcPolledFdWindows internally.
758 class GrpcPolledFdWrapper : public GrpcPolledFd {
759  public:
GrpcPolledFdWrapper(GrpcPolledFdWindows * polled_fd)760   explicit GrpcPolledFdWrapper(GrpcPolledFdWindows* polled_fd)
761       : polled_fd_(polled_fd) {}
762 
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)763   void RegisterForOnReadableLocked(
764       absl::AnyInvocable<void(absl::Status)> read_closure) override {
765     polled_fd_->RegisterForOnReadableLocked(std::move(read_closure));
766   }
767 
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)768   void RegisterForOnWriteableLocked(
769       absl::AnyInvocable<void(absl::Status)> write_closure) override {
770     polled_fd_->RegisterForOnWriteableLocked(std::move(write_closure));
771   }
772 
IsFdStillReadableLocked()773   bool IsFdStillReadableLocked() override {
774     return polled_fd_->IsFdStillReadableLocked();
775   }
776 
ShutdownLocked(absl::Status error)777   bool ShutdownLocked(absl::Status error) override {
778     return polled_fd_->ShutdownLocked(error);
779   }
780 
GetWrappedAresSocketLocked()781   ares_socket_t GetWrappedAresSocketLocked() override {
782     return polled_fd_->GetWrappedAresSocketLocked();
783   }
784 
GetName() const785   const char* GetName() const override { return polled_fd_->GetName(); }
786 
787  private:
788   GrpcPolledFdWindows* polled_fd_;
789 };
790 
GrpcPolledFdFactoryWindows(IOCP * iocp)791 GrpcPolledFdFactoryWindows::GrpcPolledFdFactoryWindows(IOCP* iocp)
792     : iocp_(iocp) {}
793 
~GrpcPolledFdFactoryWindows()794 GrpcPolledFdFactoryWindows::~GrpcPolledFdFactoryWindows() {}
795 
Initialize(grpc_core::Mutex * mutex,EventEngine * event_engine)796 void GrpcPolledFdFactoryWindows::Initialize(grpc_core::Mutex* mutex,
797                                             EventEngine* event_engine) {
798   mu_ = mutex;
799   event_engine_ = event_engine;
800 }
801 
NewGrpcPolledFdLocked(ares_socket_t as)802 std::unique_ptr<GrpcPolledFd> GrpcPolledFdFactoryWindows::NewGrpcPolledFdLocked(
803     ares_socket_t as) {
804   auto it = sockets_.find(as);
805   GPR_ASSERT(it != sockets_.end());
806   return std::make_unique<GrpcPolledFdWrapper>(it->second.get());
807 }
808 
ConfigureAresChannelLocked(ares_channel channel)809 void GrpcPolledFdFactoryWindows::ConfigureAresChannelLocked(
810     ares_channel channel) {
811   static const struct ares_socket_functions kCustomSockFuncs = {
812       /*asocket=*/&CustomSockFuncs::Socket,
813       /*aclose=*/&CustomSockFuncs::CloseSocket,
814       /*aconnect=*/&CustomSockFuncs::Connect,
815       /*arecvfrom=*/&CustomSockFuncs::RecvFrom,
816       /*asendv=*/&CustomSockFuncs::SendV,
817   };
818   ares_set_socket_functions(channel, &kCustomSockFuncs, this);
819 }
820 
821 }  // namespace experimental
822 }  // namespace grpc_event_engine
823 
824 #endif  // GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
825