1 // Copyright 2023 The gRPC Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <grpc/support/port_platform.h>
16
17 #include "src/core/lib/iomgr/port.h" // IWYU pragma: keep
18
19 #if GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
20
21 #include <winsock2.h>
22
23 #include <ares.h>
24
25 #include "absl/functional/any_invocable.h"
26 #include "absl/status/status.h"
27 #include "absl/strings/str_format.h"
28
29 #include <grpc/support/log_windows.h>
30
31 #include "src/core/lib/address_utils/sockaddr_utils.h"
32 #include "src/core/lib/event_engine/ares_resolver.h"
33 #include "src/core/lib/event_engine/grpc_polled_fd.h"
34 #include "src/core/lib/event_engine/windows/grpc_polled_fd_windows.h"
35 #include "src/core/lib/event_engine/windows/win_socket.h"
36 #include "src/core/lib/gprpp/debug_location.h"
37 #include "src/core/lib/gprpp/sync.h"
38 #include "src/core/lib/iomgr/error.h"
39 #include "src/core/lib/slice/slice.h"
40
// TODO(apolcyn): remove this hack after fixing upstream.
// Our grpc/c-ares code on Windows uses the ares_set_socket_functions API,
// whose "sendv" callback takes a "struct iovec" array. On Windows that type is
// defined inside a c-ares header that is not public, so it is redeclared here.
// NOTE(review): the field order and types presumably have to match c-ares's
// internal definition, since c-ares builds the iovec array it passes to the
// sendv callback; verify against the c-ares version in use.
// See https://github.com/c-ares/c-ares/issues/206.
struct iovec {
  void* iov_base;  // Start of the buffer.
  size_t iov_len;  // Number of bytes at iov_base.
};
50
51 namespace grpc_event_engine {
52 namespace experimental {
53 namespace {
54
// Capacity, in bytes, of the buffer that WSARecvFrom fills with the source
// address of received data.
constexpr int kRecvFromSourceAddrSize{200};
// Size, in bytes, of the buffer allocated for each overlapped read.
constexpr int kReadBufferSize{4192};
57
FlattenIovec(const struct iovec * iov,int iov_count)58 grpc_slice FlattenIovec(const struct iovec* iov, int iov_count) {
59 int total = 0;
60 for (int i = 0; i < iov_count; i++) {
61 total += iov[i].iov_len;
62 }
63 grpc_slice out = GRPC_SLICE_MALLOC(total);
64 size_t cur = 0;
65 for (int i = 0; i < iov_count; i++) {
66 for (size_t k = 0; k < iov[i].iov_len; k++) {
67 GRPC_SLICE_START_PTR(out)
68 [cur++] = (static_cast<char*>(iov[i].iov_base))[k];
69 }
70 }
71 return out;
72 }
73
74 } // namespace
75
76 // c-ares reads and takes action on the error codes of the
77 // "virtual socket operations" in this file, via the WSAGetLastError
78 // APIs. If code in this file wants to set a specific WSA error that
79 // c-ares should read, it must do so by calling SetWSAError() on the
80 // WSAErrorContext instance passed to it. A WSAErrorContext must only be
81 // instantiated at the top of the virtual socket function callstack.
82 class WSAErrorContext {
83 public:
WSAErrorContext()84 explicit WSAErrorContext(){};
85
~WSAErrorContext()86 ~WSAErrorContext() {
87 if (error_ != 0) {
88 WSASetLastError(error_);
89 }
90 }
91
92 // Disallow copy and assignment operators
93 WSAErrorContext(const WSAErrorContext&) = delete;
94 WSAErrorContext& operator=(const WSAErrorContext&) = delete;
95
SetWSAError(int error)96 void SetWSAError(int error) { error_ = error; }
97
98 private:
99 int error_ = 0;
100 };
101
102 // c-ares creates its own sockets and is meant to read them when readable and
103 // write them when writeable. To fit this socket usage model into the grpc
104 // windows poller (which gives notifications when attempted reads and writes
105 // are actually fulfilled rather than possible), this GrpcPolledFdWindows
106 // class takes advantage of the ares_set_socket_functions API and acts as a
107 // virtual socket. It holds its own read and write buffers which are written
108 // to and read from c-ares and are used with the grpc windows poller, and it,
109 // e.g., manufactures virtual socket error codes when it e.g. needs to tell
110 // the c-ares library to wait for an async read.
111 class GrpcPolledFdWindows : public GrpcPolledFd {
112 public:
GrpcPolledFdWindows(std::unique_ptr<WinSocket> winsocket,grpc_core::Mutex * mu,int address_family,int socket_type,EventEngine * event_engine)113 GrpcPolledFdWindows(std::unique_ptr<WinSocket> winsocket,
114 grpc_core::Mutex* mu, int address_family, int socket_type,
115 EventEngine* event_engine)
116 : name_(absl::StrFormat("c-ares socket: %" PRIdPTR,
117 winsocket->raw_socket())),
118 address_family_(address_family),
119 socket_type_(socket_type),
120 mu_(mu),
121 winsocket_(std::move(winsocket)),
122 read_buf_(grpc_empty_slice()),
123 write_buf_(grpc_empty_slice()),
124 outer_read_closure_([this]() { OnIocpReadable(); }),
__anonfa030b620302() 125 outer_write_closure_([this]() { OnIocpWriteable(); }),
__anonfa030b620402() 126 on_tcp_connect_locked_([this]() { OnTcpConnect(); }),
127 event_engine_(event_engine) {}
128
~GrpcPolledFdWindows()129 ~GrpcPolledFdWindows() override {
130 GRPC_ARES_RESOLVER_TRACE_LOG(
131 "fd:|%s| ~GrpcPolledFdWindows shutdown_called_: %d ", GetName(),
132 shutdown_called_);
133 grpc_core::CSliceUnref(read_buf_);
134 grpc_core::CSliceUnref(write_buf_);
135 GPR_ASSERT(read_closure_ == nullptr);
136 GPR_ASSERT(write_closure_ == nullptr);
137 if (!shutdown_called_) {
138 winsocket_->Shutdown(DEBUG_LOCATION, "~GrpcPolledFdWindows");
139 }
140 }
141
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)142 void RegisterForOnReadableLocked(
143 absl::AnyInvocable<void(absl::Status)> read_closure) override {
144 GPR_ASSERT(read_closure_ == nullptr);
145 read_closure_ = std::move(read_closure);
146 grpc_core::CSliceUnref(read_buf_);
147 GPR_ASSERT(!read_buf_has_data_);
148 read_buf_ = GRPC_SLICE_MALLOC(kReadBufferSize);
149 if (connect_done_) {
150 ContinueRegisterForOnReadableLocked();
151 } else {
152 GPR_ASSERT(pending_continue_register_for_on_readable_locked_ == false);
153 pending_continue_register_for_on_readable_locked_ = true;
154 }
155 }
156
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)157 void RegisterForOnWriteableLocked(
158 absl::AnyInvocable<void(absl::Status)> write_closure) override {
159 if (socket_type_ == SOCK_DGRAM) {
160 GRPC_ARES_RESOLVER_TRACE_LOG(
161 "fd:|%s| RegisterForOnWriteableLocked called", GetName());
162 } else {
163 GPR_ASSERT(socket_type_ == SOCK_STREAM);
164 GRPC_ARES_RESOLVER_TRACE_LOG(
165 "fd:|%s| RegisterForOnWriteableLocked called tcp_write_state_: %d "
166 "connect_done_: %d",
167 GetName(), tcp_write_state_, connect_done_);
168 }
169 GPR_ASSERT(write_closure_ == nullptr);
170 write_closure_ = std::move(write_closure);
171 if (!connect_done_) {
172 GPR_ASSERT(!pending_continue_register_for_on_writeable_locked_);
173 pending_continue_register_for_on_writeable_locked_ = true;
174 } else {
175 ContinueRegisterForOnWriteableLocked();
176 }
177 }
178
IsFdStillReadableLocked()179 bool IsFdStillReadableLocked() override { return read_buf_has_data_; }
180
ShutdownLocked(absl::Status error)181 bool ShutdownLocked(absl::Status error) override {
182 GPR_ASSERT(!shutdown_called_);
183 if (!absl::IsCancelled(error)) {
184 return false;
185 }
186 GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| ShutdownLocked", GetName());
187 shutdown_called_ = true;
188 // The socket is disconnected and closed here since this is an external
189 // cancel request, e.g. a timeout. c-ares shouldn't do anything on the
190 // socket after this point except calling close which should then destroy
191 // the GrpcPolledFdWindows object.
192 winsocket_->Shutdown(DEBUG_LOCATION, "GrpcPolledFdWindows::ShutdownLocked");
193 return true;
194 }
195
GetWrappedAresSocketLocked()196 ares_socket_t GetWrappedAresSocketLocked() override {
197 return winsocket_->raw_socket();
198 }
199
GetName() const200 const char* GetName() const override { return name_.c_str(); }
201
RecvFrom(WSAErrorContext * wsa_error_ctx,void * data,ares_socket_t data_len,int,struct sockaddr * from,ares_socklen_t * from_len)202 ares_ssize_t RecvFrom(WSAErrorContext* wsa_error_ctx, void* data,
203 ares_socket_t data_len, int /* flags */,
204 struct sockaddr* from, ares_socklen_t* from_len) {
205 GRPC_ARES_RESOLVER_TRACE_LOG(
206 "fd:|%s| RecvFrom called read_buf_has_data:%d Current read buf "
207 "length:|%d|",
208 GetName(), read_buf_has_data_, GRPC_SLICE_LENGTH(read_buf_));
209 if (!read_buf_has_data_) {
210 wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
211 return -1;
212 }
213 ares_ssize_t bytes_read = 0;
214 for (size_t i = 0; i < GRPC_SLICE_LENGTH(read_buf_) && i < data_len; i++) {
215 (static_cast<char*>(data))[i] = GRPC_SLICE_START_PTR(read_buf_)[i];
216 bytes_read++;
217 }
218 read_buf_ = grpc_slice_sub_no_ref(read_buf_, bytes_read,
219 GRPC_SLICE_LENGTH(read_buf_));
220 if (GRPC_SLICE_LENGTH(read_buf_) == 0) {
221 read_buf_has_data_ = false;
222 }
223 // c-ares overloads this recv_from virtual socket function to receive
224 // data on both UDP and TCP sockets, and from is nullptr for TCP.
225 if (from != nullptr) {
226 GPR_ASSERT(*from_len <= recv_from_source_addr_len_);
227 memcpy(from, &recv_from_source_addr_, recv_from_source_addr_len_);
228 *from_len = recv_from_source_addr_len_;
229 }
230 return bytes_read;
231 }
232
SendV(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)233 ares_ssize_t SendV(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
234 int iov_count) {
235 GRPC_ARES_RESOLVER_TRACE_LOG(
236 "fd:|%s| SendV called connect_done_:%d wsa_connect_error_:%d",
237 GetName(), connect_done_, wsa_connect_error_);
238 if (!connect_done_) {
239 wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
240 return -1;
241 }
242 if (wsa_connect_error_ != 0) {
243 wsa_error_ctx->SetWSAError(wsa_connect_error_);
244 return -1;
245 }
246 switch (socket_type_) {
247 case SOCK_DGRAM:
248 return SendVUDP(wsa_error_ctx, iov, iov_count);
249 case SOCK_STREAM:
250 return SendVTCP(wsa_error_ctx, iov, iov_count);
251 default:
252 abort();
253 }
254 }
255
Connect(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)256 int Connect(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
257 ares_socklen_t target_len) {
258 switch (socket_type_) {
259 case SOCK_DGRAM:
260 return ConnectUDP(wsa_error_ctx, target, target_len);
261 case SOCK_STREAM:
262 return ConnectTCP(wsa_error_ctx, target, target_len);
263 default:
264 grpc_core::Crash(
265 absl::StrFormat("Unknown socket_type_: %d", socket_type_));
266 }
267 }
268
269 private:
270 enum WriteState {
271 WRITE_IDLE,
272 WRITE_REQUESTED,
273 WRITE_PENDING,
274 WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY,
275 };
276
ScheduleAndNullReadClosure(absl::Status error)277 void ScheduleAndNullReadClosure(absl::Status error) {
278 event_engine_->Run([read_closure = std::move(read_closure_),
279 error]() mutable { read_closure(error); });
280 read_closure_ = nullptr;
281 }
282
ScheduleAndNullWriteClosure(absl::Status error)283 void ScheduleAndNullWriteClosure(absl::Status error) {
284 event_engine_->Run([write_closure = std::move(write_closure_),
285 error]() mutable { write_closure(error); });
286 write_closure_ = nullptr;
287 }
288
ContinueRegisterForOnReadableLocked()289 void ContinueRegisterForOnReadableLocked() {
290 GRPC_ARES_RESOLVER_TRACE_LOG(
291 "fd:|%s| ContinueRegisterForOnReadableLocked "
292 "wsa_connect_error_:%d",
293 GetName(), wsa_connect_error_);
294 GPR_ASSERT(connect_done_);
295 if (wsa_connect_error_ != 0) {
296 ScheduleAndNullReadClosure(GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
297 return;
298 }
299 WSABUF buffer;
300 buffer.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(read_buf_));
301 buffer.len = GRPC_SLICE_LENGTH(read_buf_);
302 recv_from_source_addr_len_ = sizeof(recv_from_source_addr_);
303 DWORD flags = 0;
304 winsocket_->NotifyOnRead(&outer_read_closure_);
305 if (WSARecvFrom(winsocket_->raw_socket(), &buffer, 1, nullptr, &flags,
306 reinterpret_cast<sockaddr*>(recv_from_source_addr_),
307 &recv_from_source_addr_len_,
308 winsocket_->read_info()->overlapped(), nullptr) != 0) {
309 int wsa_last_error = WSAGetLastError();
310 char* msg = gpr_format_message(wsa_last_error);
311 GRPC_ARES_RESOLVER_TRACE_LOG(
312 "fd:|%s| ContinueRegisterForOnReadableLocked WSARecvFrom error "
313 "code:|%d| "
314 "msg:|%s|",
315 GetName(), wsa_last_error, msg);
316 gpr_free(msg);
317 if (wsa_last_error != WSA_IO_PENDING) {
318 winsocket_->UnregisterReadCallback();
319 ScheduleAndNullReadClosure(
320 GRPC_WSA_ERROR(wsa_last_error, "WSARecvFrom"));
321 return;
322 }
323 }
324 }
325
ContinueRegisterForOnWriteableLocked()326 void ContinueRegisterForOnWriteableLocked() {
327 GRPC_ARES_RESOLVER_TRACE_LOG(
328 "fd:|%s| ContinueRegisterForOnWriteableLocked "
329 "wsa_connect_error_:%d",
330 GetName(), wsa_connect_error_);
331 GPR_ASSERT(connect_done_);
332 if (wsa_connect_error_ != 0) {
333 ScheduleAndNullWriteClosure(
334 GRPC_WSA_ERROR(wsa_connect_error_, "connect"));
335 return;
336 }
337 if (socket_type_ == SOCK_DGRAM) {
338 ScheduleAndNullWriteClosure(absl::OkStatus());
339 return;
340 }
341 GPR_ASSERT(socket_type_ == SOCK_STREAM);
342 int wsa_error_code = 0;
343 switch (tcp_write_state_) {
344 case WRITE_IDLE:
345 ScheduleAndNullWriteClosure(absl::OkStatus());
346 break;
347 case WRITE_REQUESTED:
348 tcp_write_state_ = WRITE_PENDING;
349 winsocket_->NotifyOnWrite(&outer_write_closure_);
350 if (SendWriteBuf(nullptr, winsocket_->write_info()->overlapped(),
351 &wsa_error_code) != 0) {
352 winsocket_->UnregisterWriteCallback();
353 ScheduleAndNullWriteClosure(
354 GRPC_WSA_ERROR(wsa_error_code, "WSASend (overlapped)"));
355 return;
356 }
357 break;
358 case WRITE_PENDING:
359 case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
360 grpc_core::Crash(
361 absl::StrFormat("Invalid tcp_write_state_: %d", tcp_write_state_));
362 }
363 }
364
SendWriteBuf(LPDWORD bytes_sent_ptr,LPWSAOVERLAPPED overlapped,int * wsa_error_code)365 int SendWriteBuf(LPDWORD bytes_sent_ptr, LPWSAOVERLAPPED overlapped,
366 int* wsa_error_code) {
367 WSABUF buf;
368 buf.len = GRPC_SLICE_LENGTH(write_buf_);
369 buf.buf = reinterpret_cast<char*>(GRPC_SLICE_START_PTR(write_buf_));
370 DWORD flags = 0;
371 int out = WSASend(winsocket_->raw_socket(), &buf, 1, bytes_sent_ptr, flags,
372 overlapped, nullptr);
373 *wsa_error_code = WSAGetLastError();
374 GRPC_ARES_RESOLVER_TRACE_LOG(
375 "fd:|%s| SendWriteBuf WSASend buf.len:%d *bytes_sent_ptr:%d "
376 "overlapped:%p "
377 "return:%d *wsa_error_code:%d",
378 GetName(), buf.len, bytes_sent_ptr != nullptr ? *bytes_sent_ptr : 0,
379 overlapped, out, *wsa_error_code);
380 return out;
381 }
382
SendVUDP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)383 ares_ssize_t SendVUDP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
384 int iov_count) {
385 // c-ares doesn't handle retryable errors on writes of UDP sockets.
386 // Therefore, the sendv handler for UDP sockets must only attempt
387 // to write everything inline.
388 GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| SendVUDP called", GetName());
389 GPR_ASSERT(GRPC_SLICE_LENGTH(write_buf_) == 0);
390 grpc_core::CSliceUnref(write_buf_);
391 write_buf_ = FlattenIovec(iov, iov_count);
392 DWORD bytes_sent = 0;
393 int wsa_error_code = 0;
394 if (SendWriteBuf(&bytes_sent, nullptr, &wsa_error_code) != 0) {
395 grpc_core::CSliceUnref(write_buf_);
396 write_buf_ = grpc_empty_slice();
397 wsa_error_ctx->SetWSAError(wsa_error_code);
398 char* msg = gpr_format_message(wsa_error_code);
399 GRPC_ARES_RESOLVER_TRACE_LOG(
400 "fd:|%s| SendVUDP SendWriteBuf error code:%d msg:|%s|", GetName(),
401 wsa_error_code, msg);
402 gpr_free(msg);
403 return -1;
404 }
405 write_buf_ = grpc_slice_sub_no_ref(write_buf_, bytes_sent,
406 GRPC_SLICE_LENGTH(write_buf_));
407 return bytes_sent;
408 }
409
SendVTCP(WSAErrorContext * wsa_error_ctx,const struct iovec * iov,int iov_count)410 ares_ssize_t SendVTCP(WSAErrorContext* wsa_error_ctx, const struct iovec* iov,
411 int iov_count) {
412 // The "sendv" handler on TCP sockets buffers up write
413 // requests and returns an artificial WSAEWOULDBLOCK. Writing that buffer
414 // out in the background, and making further send progress in general, will
415 // happen as long as c-ares continues to show interest in writeability on
416 // this fd.
417 GRPC_ARES_RESOLVER_TRACE_LOG("fd:|%s| SendVTCP called tcp_write_state_:%d",
418 GetName(), tcp_write_state_);
419 switch (tcp_write_state_) {
420 case WRITE_IDLE:
421 tcp_write_state_ = WRITE_REQUESTED;
422 grpc_core::CSliceUnref(write_buf_);
423 write_buf_ = FlattenIovec(iov, iov_count);
424 wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
425 return -1;
426 case WRITE_REQUESTED:
427 case WRITE_PENDING:
428 wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
429 return -1;
430 case WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY:
431 // c-ares is retrying a send on data that we previously returned
432 // WSAEWOULDBLOCK for, but then subsequently wrote out in the
433 // background. Right now, we assume that c-ares is retrying the same
434 // send again. If c-ares still needs to send even more data, we'll get
435 // to it eventually.
436 grpc_slice currently_attempted = FlattenIovec(iov, iov_count);
437 GPR_ASSERT(GRPC_SLICE_LENGTH(currently_attempted) >=
438 GRPC_SLICE_LENGTH(write_buf_));
439 ares_ssize_t total_sent = 0;
440 for (size_t i = 0; i < GRPC_SLICE_LENGTH(write_buf_); i++) {
441 GPR_ASSERT(GRPC_SLICE_START_PTR(currently_attempted)[i] ==
442 GRPC_SLICE_START_PTR(write_buf_)[i]);
443 total_sent++;
444 }
445 grpc_core::CSliceUnref(currently_attempted);
446 tcp_write_state_ = WRITE_IDLE;
447 return total_sent;
448 }
449 grpc_core::Crash(
450 absl::StrFormat("Unknown tcp_write_state_: %d", tcp_write_state_));
451 }
452
OnTcpConnect()453 void OnTcpConnect() {
454 grpc_core::MutexLock lock(mu_);
455 GRPC_ARES_RESOLVER_TRACE_LOG(
456 "fd:%s InnerOnTcpConnectLocked "
457 "pending_register_for_readable:%d"
458 " pending_register_for_writeable:%d",
459 GetName(), pending_continue_register_for_on_readable_locked_,
460 pending_continue_register_for_on_writeable_locked_);
461 GPR_ASSERT(!connect_done_);
462 connect_done_ = true;
463 GPR_ASSERT(wsa_connect_error_ == 0);
464 if (shutdown_called_) {
465 wsa_connect_error_ = WSA_OPERATION_ABORTED;
466 } else {
467 DWORD transferred_bytes = 0;
468 DWORD flags;
469 BOOL wsa_success = WSAGetOverlappedResult(
470 winsocket_->raw_socket(), winsocket_->write_info()->overlapped(),
471 &transferred_bytes, FALSE, &flags);
472 GPR_ASSERT(transferred_bytes == 0);
473 if (!wsa_success) {
474 wsa_connect_error_ = WSAGetLastError();
475 char* msg = gpr_format_message(wsa_connect_error_);
476 GRPC_ARES_RESOLVER_TRACE_LOG(
477 "fd:%s InnerOnTcpConnectLocked WSA overlapped result code:%d "
478 "msg:|%s|",
479 GetName(), wsa_connect_error_, msg);
480 gpr_free(msg);
481 }
482 }
483 if (pending_continue_register_for_on_readable_locked_) {
484 ContinueRegisterForOnReadableLocked();
485 }
486 if (pending_continue_register_for_on_writeable_locked_) {
487 ContinueRegisterForOnWriteableLocked();
488 }
489 }
490
ConnectUDP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)491 int ConnectUDP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
492 ares_socklen_t target_len) {
493 GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectUDP", GetName());
494 GPR_ASSERT(!connect_done_);
495 GPR_ASSERT(wsa_connect_error_ == 0);
496 SOCKET s = winsocket_->raw_socket();
497 int out =
498 WSAConnect(s, target, target_len, nullptr, nullptr, nullptr, nullptr);
499 wsa_connect_error_ = WSAGetLastError();
500 wsa_error_ctx->SetWSAError(wsa_connect_error_);
501 connect_done_ = true;
502 char* msg = gpr_format_message(wsa_connect_error_);
503 GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s WSAConnect error code:|%d| msg:|%s|",
504 GetName(), wsa_connect_error_, msg);
505 gpr_free(msg);
506 // c-ares expects a posix-style connect API
507 return out == 0 ? 0 : -1;
508 }
509
ConnectTCP(WSAErrorContext * wsa_error_ctx,const struct sockaddr * target,ares_socklen_t target_len)510 int ConnectTCP(WSAErrorContext* wsa_error_ctx, const struct sockaddr* target,
511 ares_socklen_t target_len) {
512 GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectTCP", GetName());
513 LPFN_CONNECTEX ConnectEx;
514 GUID guid = WSAID_CONNECTEX;
515 DWORD ioctl_num_bytes;
516 SOCKET s = winsocket_->raw_socket();
517 if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
518 &ConnectEx, sizeof(ConnectEx), &ioctl_num_bytes, nullptr,
519 nullptr) != 0) {
520 int wsa_last_error = WSAGetLastError();
521 wsa_error_ctx->SetWSAError(wsa_last_error);
522 char* msg = gpr_format_message(wsa_last_error);
523 GRPC_ARES_RESOLVER_TRACE_LOG(
524 "fd:%s WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER) error code:%d "
525 "msg:|%s|",
526 GetName(), wsa_last_error, msg);
527 gpr_free(msg);
528 connect_done_ = true;
529 wsa_connect_error_ = wsa_last_error;
530 return -1;
531 }
532 grpc_resolved_address wildcard4_addr;
533 grpc_resolved_address wildcard6_addr;
534 grpc_sockaddr_make_wildcards(0, &wildcard4_addr, &wildcard6_addr);
535 grpc_resolved_address* local_address = nullptr;
536 if (address_family_ == AF_INET) {
537 local_address = &wildcard4_addr;
538 } else {
539 local_address = &wildcard6_addr;
540 }
541 if (bind(s, reinterpret_cast<struct sockaddr*>(local_address->addr),
542 static_cast<int>(local_address->len)) != 0) {
543 int wsa_last_error = WSAGetLastError();
544 wsa_error_ctx->SetWSAError(wsa_last_error);
545 char* msg = gpr_format_message(wsa_last_error);
546 GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s bind error code:%d msg:|%s|",
547 GetName(), wsa_last_error, msg);
548 gpr_free(msg);
549 connect_done_ = true;
550 wsa_connect_error_ = wsa_last_error;
551 return -1;
552 }
553 int out = 0;
554 // Register an async OnTcpConnect callback here since it is required by the
555 // WinSocket API.
556 winsocket_->NotifyOnWrite(&on_tcp_connect_locked_);
557 if (ConnectEx(s, target, target_len, nullptr, 0, nullptr,
558 winsocket_->write_info()->overlapped()) == 0) {
559 out = -1;
560 int wsa_last_error = WSAGetLastError();
561 wsa_error_ctx->SetWSAError(wsa_last_error);
562 char* msg = gpr_format_message(wsa_last_error);
563 GRPC_ARES_RESOLVER_TRACE_LOG("fd:%s ConnectEx error code:%d msg:|%s|",
564 GetName(), wsa_last_error, msg);
565 gpr_free(msg);
566 if (wsa_last_error == WSA_IO_PENDING) {
567 // c-ares only understands WSAEINPROGRESS and EWOULDBLOCK error codes on
568 // connect, but an async connect on IOCP socket will give
569 // WSA_IO_PENDING, so we need to convert.
570 wsa_error_ctx->SetWSAError(WSAEWOULDBLOCK);
571 } else {
572 winsocket_->UnregisterWriteCallback();
573 // By returning a non-retryable error to c-ares at this point,
574 // we're aborting the possibility of any future operations on this fd.
575 connect_done_ = true;
576 wsa_connect_error_ = wsa_last_error;
577 return -1;
578 }
579 }
580 return out;
581 }
582
583 // TODO(apolcyn): improve this error handling to be less conversative.
584 // An e.g. ECONNRESET error here should result in errors when
585 // c-ares reads from this socket later, but it shouldn't necessarily cancel
586 // the entire resolution attempt. Doing so will allow the "inject broken
587 // nameserver list" test to pass on Windows.
OnIocpReadable()588 void OnIocpReadable() {
589 grpc_core::MutexLock lock(mu_);
590 absl::Status error;
591 if (winsocket_->read_info()->result().wsa_error != 0) {
592 // WSAEMSGSIZE would be due to receiving more data
593 // than our read buffer's fixed capacity. Assume that
594 // the connection is TCP and read the leftovers
595 // in subsequent c-ares reads.
596 if (winsocket_->read_info()->result().wsa_error != WSAEMSGSIZE) {
597 error = GRPC_WSA_ERROR(winsocket_->read_info()->result().wsa_error,
598 "OnIocpReadableInner");
599 GRPC_ARES_RESOLVER_TRACE_LOG(
600 "fd:|%s| OnIocpReadableInner winsocket_->read_info.wsa_error "
601 "code:|%d| msg:|%s|",
602 GetName(), winsocket_->read_info()->result().wsa_error,
603 grpc_core::StatusToString(error).c_str());
604 }
605 }
606 if (error.ok()) {
607 read_buf_ = grpc_slice_sub_no_ref(
608 read_buf_, 0, winsocket_->read_info()->result().bytes_transferred);
609 read_buf_has_data_ = true;
610 } else {
611 grpc_core::CSliceUnref(read_buf_);
612 read_buf_ = grpc_empty_slice();
613 }
614 GRPC_ARES_RESOLVER_TRACE_LOG(
615 "fd:|%s| OnIocpReadable finishing. read buf length now:|%d|", GetName(),
616 GRPC_SLICE_LENGTH(read_buf_));
617 ScheduleAndNullReadClosure(error);
618 }
619
OnIocpWriteable()620 void OnIocpWriteable() {
621 grpc_core::MutexLock lock(mu_);
622 GRPC_ARES_RESOLVER_TRACE_LOG("OnIocpWriteableInner. fd:|%s|", GetName());
623 GPR_ASSERT(socket_type_ == SOCK_STREAM);
624 absl::Status error;
625 if (winsocket_->write_info()->result().wsa_error != 0) {
626 error = GRPC_WSA_ERROR(winsocket_->write_info()->result().wsa_error,
627 "OnIocpWriteableInner");
628 GRPC_ARES_RESOLVER_TRACE_LOG(
629 "fd:|%s| OnIocpWriteableInner. winsocket_->write_info.wsa_error "
630 "code:|%d| msg:|%s|",
631 GetName(), winsocket_->write_info()->result().wsa_error,
632 grpc_core::StatusToString(error).c_str());
633 }
634 GPR_ASSERT(tcp_write_state_ == WRITE_PENDING);
635 if (error.ok()) {
636 tcp_write_state_ = WRITE_WAITING_FOR_VERIFICATION_UPON_RETRY;
637 write_buf_ = grpc_slice_sub_no_ref(
638 write_buf_, 0, winsocket_->write_info()->result().bytes_transferred);
639 GRPC_ARES_RESOLVER_TRACE_LOG(
640 "fd:|%s| OnIocpWriteableInner. bytes transferred:%d", GetName(),
641 winsocket_->write_info()->result().bytes_transferred);
642 } else {
643 grpc_core::CSliceUnref(write_buf_);
644 write_buf_ = grpc_empty_slice();
645 }
646 ScheduleAndNullWriteClosure(error);
647 }
648
649 const std::string name_;
650 const int address_family_;
651 const int socket_type_;
652 grpc_core::Mutex* mu_;
653 std::unique_ptr<WinSocket> winsocket_;
654 char recv_from_source_addr_[kRecvFromSourceAddrSize];
655 ares_socklen_t recv_from_source_addr_len_;
656 grpc_slice read_buf_;
657 bool read_buf_has_data_ = false;
658 grpc_slice write_buf_;
659 absl::AnyInvocable<void(absl::Status)> read_closure_;
660 absl::AnyInvocable<void(absl::Status)> write_closure_;
661 AnyInvocableClosure outer_read_closure_;
662 AnyInvocableClosure outer_write_closure_;
663 bool shutdown_called_ = false;
664 // State related to TCP sockets
665 AnyInvocableClosure on_tcp_connect_locked_;
666 bool connect_done_ = false;
667 int wsa_connect_error_ = 0;
668 WriteState tcp_write_state_ = WRITE_IDLE;
669 // We don't run register_for_{readable,writeable} logic until
670 // a socket is connected. In the interim, we queue readable/writeable
671 // registrations with the following state.
672 bool pending_continue_register_for_on_readable_locked_ = false;
673 bool pending_continue_register_for_on_writeable_locked_ = false;
674 // This pointer is initialized from the stored pointer inside the shared
675 // pointer owned by the AresResolver and should be valid at the time of use.
676 EventEngine* event_engine_;
677 };
678
679 // These virtual socket functions are called from within the c-ares
680 // library. These methods generally dispatch those socket calls to the
681 // appropriate methods. The virtual "socket" and "close" methods are
682 // special and instead create/add and remove/destroy GrpcPolledFdWindows
683 // objects.
684 class CustomSockFuncs {
685 public:
Socket(int af,int type,int protocol,void * user_data)686 static ares_socket_t Socket(int af, int type, int protocol, void* user_data) {
687 if (type != SOCK_DGRAM && type != SOCK_STREAM) {
688 GRPC_ARES_RESOLVER_TRACE_LOG("Socket called with invalid socket type:%d",
689 type);
690 return INVALID_SOCKET;
691 }
692 GrpcPolledFdFactoryWindows* self =
693 static_cast<GrpcPolledFdFactoryWindows*>(user_data);
694 SOCKET s = WSASocket(af, type, protocol, nullptr, 0,
695 IOCP::GetDefaultSocketFlags());
696 if (s == INVALID_SOCKET) {
697 GRPC_ARES_RESOLVER_TRACE_LOG(
698 "WSASocket failed with params af:%d type:%d protocol:%d", af, type,
699 protocol);
700 return INVALID_SOCKET;
701 }
702 if (type == SOCK_STREAM) {
703 absl::Status error = PrepareSocket(s);
704 if (!error.ok()) {
705 GRPC_ARES_RESOLVER_TRACE_LOG("WSAIoctl failed with error: %s",
706 grpc_core::StatusToString(error).c_str());
707 return INVALID_SOCKET;
708 }
709 }
710 auto polled_fd = std::make_unique<GrpcPolledFdWindows>(
711 self->iocp_->Watch(s), self->mu_, af, type, self->event_engine_);
712 GRPC_ARES_RESOLVER_TRACE_LOG(
713 "fd:|%s| created with params af:%d type:%d protocol:%d",
714 polled_fd->GetName(), af, type, protocol);
715 GPR_ASSERT(self->sockets_.insert({s, std::move(polled_fd)}).second);
716 return s;
717 }
718
Connect(ares_socket_t as,const struct sockaddr * target,ares_socklen_t target_len,void * user_data)719 static int Connect(ares_socket_t as, const struct sockaddr* target,
720 ares_socklen_t target_len, void* user_data) {
721 WSAErrorContext wsa_error_ctx;
722 GrpcPolledFdFactoryWindows* self =
723 static_cast<GrpcPolledFdFactoryWindows*>(user_data);
724 auto it = self->sockets_.find(as);
725 GPR_ASSERT(it != self->sockets_.end());
726 return it->second->Connect(&wsa_error_ctx, target, target_len);
727 }
728
SendV(ares_socket_t as,const struct iovec * iov,int iovec_count,void * user_data)729 static ares_ssize_t SendV(ares_socket_t as, const struct iovec* iov,
730 int iovec_count, void* user_data) {
731 WSAErrorContext wsa_error_ctx;
732 GrpcPolledFdFactoryWindows* self =
733 static_cast<GrpcPolledFdFactoryWindows*>(user_data);
734 auto it = self->sockets_.find(as);
735 GPR_ASSERT(it != self->sockets_.end());
736 return it->second->SendV(&wsa_error_ctx, iov, iovec_count);
737 }
738
RecvFrom(ares_socket_t as,void * data,size_t data_len,int flags,struct sockaddr * from,ares_socklen_t * from_len,void * user_data)739 static ares_ssize_t RecvFrom(ares_socket_t as, void* data, size_t data_len,
740 int flags, struct sockaddr* from,
741 ares_socklen_t* from_len, void* user_data) {
742 WSAErrorContext wsa_error_ctx;
743 GrpcPolledFdFactoryWindows* self =
744 static_cast<GrpcPolledFdFactoryWindows*>(user_data);
745 auto it = self->sockets_.find(as);
746 GPR_ASSERT(it != self->sockets_.end());
747 return it->second->RecvFrom(&wsa_error_ctx, data, data_len, flags, from,
748 from_len);
749 }
750
CloseSocket(SOCKET s,void *)751 static int CloseSocket(SOCKET s, void*) {
752 GRPC_ARES_RESOLVER_TRACE_LOG("c-ares socket: %d CloseSocket", s);
753 return 0;
754 }
755 };
756
757 // Adapter to hold the ownership of GrpcPolledFdWindows internally.
758 class GrpcPolledFdWrapper : public GrpcPolledFd {
759 public:
GrpcPolledFdWrapper(GrpcPolledFdWindows * polled_fd)760 explicit GrpcPolledFdWrapper(GrpcPolledFdWindows* polled_fd)
761 : polled_fd_(polled_fd) {}
762
RegisterForOnReadableLocked(absl::AnyInvocable<void (absl::Status)> read_closure)763 void RegisterForOnReadableLocked(
764 absl::AnyInvocable<void(absl::Status)> read_closure) override {
765 polled_fd_->RegisterForOnReadableLocked(std::move(read_closure));
766 }
767
RegisterForOnWriteableLocked(absl::AnyInvocable<void (absl::Status)> write_closure)768 void RegisterForOnWriteableLocked(
769 absl::AnyInvocable<void(absl::Status)> write_closure) override {
770 polled_fd_->RegisterForOnWriteableLocked(std::move(write_closure));
771 }
772
IsFdStillReadableLocked()773 bool IsFdStillReadableLocked() override {
774 return polled_fd_->IsFdStillReadableLocked();
775 }
776
ShutdownLocked(absl::Status error)777 bool ShutdownLocked(absl::Status error) override {
778 return polled_fd_->ShutdownLocked(error);
779 }
780
GetWrappedAresSocketLocked()781 ares_socket_t GetWrappedAresSocketLocked() override {
782 return polled_fd_->GetWrappedAresSocketLocked();
783 }
784
GetName() const785 const char* GetName() const override { return polled_fd_->GetName(); }
786
787 private:
788 GrpcPolledFdWindows* polled_fd_;
789 };
790
GrpcPolledFdFactoryWindows(IOCP * iocp)791 GrpcPolledFdFactoryWindows::GrpcPolledFdFactoryWindows(IOCP* iocp)
792 : iocp_(iocp) {}
793
~GrpcPolledFdFactoryWindows()794 GrpcPolledFdFactoryWindows::~GrpcPolledFdFactoryWindows() {}
795
Initialize(grpc_core::Mutex * mutex,EventEngine * event_engine)796 void GrpcPolledFdFactoryWindows::Initialize(grpc_core::Mutex* mutex,
797 EventEngine* event_engine) {
798 mu_ = mutex;
799 event_engine_ = event_engine;
800 }
801
NewGrpcPolledFdLocked(ares_socket_t as)802 std::unique_ptr<GrpcPolledFd> GrpcPolledFdFactoryWindows::NewGrpcPolledFdLocked(
803 ares_socket_t as) {
804 auto it = sockets_.find(as);
805 GPR_ASSERT(it != sockets_.end());
806 return std::make_unique<GrpcPolledFdWrapper>(it->second.get());
807 }
808
ConfigureAresChannelLocked(ares_channel channel)809 void GrpcPolledFdFactoryWindows::ConfigureAresChannelLocked(
810 ares_channel channel) {
811 static const struct ares_socket_functions kCustomSockFuncs = {
812 /*asocket=*/&CustomSockFuncs::Socket,
813 /*aclose=*/&CustomSockFuncs::CloseSocket,
814 /*aconnect=*/&CustomSockFuncs::Connect,
815 /*arecvfrom=*/&CustomSockFuncs::RecvFrom,
816 /*asendv=*/&CustomSockFuncs::SendV,
817 };
818 ares_set_socket_functions(channel, &kCustomSockFuncs, this);
819 }
820
821 } // namespace experimental
822 } // namespace grpc_event_engine
823
824 #endif // GRPC_ARES == 1 && defined(GRPC_WINDOWS_SOCKET_ARES_EV_DRIVER)
825