1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_client_socket.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/check_op.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/metrics/histogram_macros.h"
15 #include "base/notreached.h"
16 #include "base/time/time.h"
17 #include "net/base/features.h"
18 #include "net/base/io_buffer.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/net_errors.h"
21 #include "net/nqe/network_quality_estimator.h"
22 #include "net/socket/socket_performance_watcher.h"
23 #include "net/traffic_annotation/network_traffic_annotation.h"
24
25 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
26 #include "base/power_monitor/power_monitor.h"
27 #endif
28
29 namespace net {
30
31 class NetLogWithSource;
32
TCPClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,net::NetLog * net_log,const net::NetLogSource & source,handles::NetworkHandle network)33 TCPClientSocket::TCPClientSocket(
34 const AddressList& addresses,
35 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
36 NetworkQualityEstimator* network_quality_estimator,
37 net::NetLog* net_log,
38 const net::NetLogSource& source,
39 handles::NetworkHandle network)
40 : TCPClientSocket(
41 std::make_unique<TCPSocket>(std::move(socket_performance_watcher),
42 net_log,
43 source),
44 addresses,
45 -1 /* current_address_index */,
46 nullptr /* bind_address */,
47 network_quality_estimator,
48 network) {}
49
TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,const IPEndPoint & peer_address)50 TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,
51 const IPEndPoint& peer_address)
52 : TCPClientSocket(std::move(connected_socket),
53 AddressList(peer_address),
54 0 /* current_address_index */,
55 nullptr /* bind_address */,
56 // TODO(https://crbug.com/1123197: Pass non-null
57 // NetworkQualityEstimator
58 nullptr /* network_quality_estimator */,
59 handles::kInvalidNetworkHandle) {}
60
TCPClientSocket(std::unique_ptr<TCPSocket> unconnected_socket,const AddressList & addresses,std::unique_ptr<IPEndPoint> bound_address,NetworkQualityEstimator * network_quality_estimator)61 TCPClientSocket::TCPClientSocket(
62 std::unique_ptr<TCPSocket> unconnected_socket,
63 const AddressList& addresses,
64 std::unique_ptr<IPEndPoint> bound_address,
65 NetworkQualityEstimator* network_quality_estimator)
66 : TCPClientSocket(std::move(unconnected_socket),
67 addresses,
68 -1 /* current_address_index */,
69 std::move(bound_address),
70 network_quality_estimator,
71 handles::kInvalidNetworkHandle) {}
72
~TCPClientSocket()73 TCPClientSocket::~TCPClientSocket() {
74 Disconnect();
75 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
76 base::PowerMonitor::RemovePowerSuspendObserver(this);
77 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
78 }
79
CreateFromBoundSocket(std::unique_ptr<TCPSocket> bound_socket,const AddressList & addresses,const IPEndPoint & bound_address,NetworkQualityEstimator * network_quality_estimator)80 std::unique_ptr<TCPClientSocket> TCPClientSocket::CreateFromBoundSocket(
81 std::unique_ptr<TCPSocket> bound_socket,
82 const AddressList& addresses,
83 const IPEndPoint& bound_address,
84 NetworkQualityEstimator* network_quality_estimator) {
85 return base::WrapUnique(new TCPClientSocket(
86 std::move(bound_socket), addresses, -1 /* current_address_index */,
87 std::make_unique<IPEndPoint>(bound_address), network_quality_estimator,
88 handles::kInvalidNetworkHandle));
89 }
90
Bind(const IPEndPoint & address)91 int TCPClientSocket::Bind(const IPEndPoint& address) {
92 if (current_address_index_ >= 0 || bind_address_) {
93 // Cannot bind the socket if we are already connected or connecting.
94 NOTREACHED();
95 return ERR_UNEXPECTED;
96 }
97
98 int result = OK;
99 if (!socket_->IsValid()) {
100 result = OpenSocket(address.GetFamily());
101 if (result != OK)
102 return result;
103 }
104
105 result = socket_->Bind(address);
106 if (result != OK)
107 return result;
108
109 bind_address_ = std::make_unique<IPEndPoint>(address);
110 return OK;
111 }
112
SetKeepAlive(bool enable,int delay)113 bool TCPClientSocket::SetKeepAlive(bool enable, int delay) {
114 return socket_->SetKeepAlive(enable, delay);
115 }
116
SetNoDelay(bool no_delay)117 bool TCPClientSocket::SetNoDelay(bool no_delay) {
118 return socket_->SetNoDelay(no_delay);
119 }
120
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)121 void TCPClientSocket::SetBeforeConnectCallback(
122 const BeforeConnectCallback& before_connect_callback) {
123 DCHECK_EQ(CONNECT_STATE_NONE, next_connect_state_);
124 before_connect_callback_ = before_connect_callback;
125 }
126
Connect(CompletionOnceCallback callback)127 int TCPClientSocket::Connect(CompletionOnceCallback callback) {
128 DCHECK(!callback.is_null());
129
130 // If connecting or already connected, then just return OK.
131 if (socket_->IsValid() && current_address_index_ >= 0)
132 return OK;
133
134 DCHECK(!read_callback_);
135 DCHECK(!write_callback_);
136
137 if (was_disconnected_on_suspend_) {
138 Disconnect();
139 was_disconnected_on_suspend_ = false;
140 }
141
142 socket_->StartLoggingMultipleConnectAttempts(addresses_);
143
144 // We will try to connect to each address in addresses_. Start with the
145 // first one in the list.
146 next_connect_state_ = CONNECT_STATE_CONNECT;
147 current_address_index_ = 0;
148
149 int rv = DoConnectLoop(OK);
150 if (rv == ERR_IO_PENDING) {
151 connect_callback_ = std::move(callback);
152 } else {
153 socket_->EndLoggingMultipleConnectAttempts(rv);
154 }
155
156 return rv;
157 }
158
TCPClientSocket(std::unique_ptr<TCPSocket> socket,const AddressList & addresses,int current_address_index,std::unique_ptr<IPEndPoint> bind_address,NetworkQualityEstimator * network_quality_estimator,handles::NetworkHandle network)159 TCPClientSocket::TCPClientSocket(
160 std::unique_ptr<TCPSocket> socket,
161 const AddressList& addresses,
162 int current_address_index,
163 std::unique_ptr<IPEndPoint> bind_address,
164 NetworkQualityEstimator* network_quality_estimator,
165 handles::NetworkHandle network)
166 : socket_(std::move(socket)),
167 bind_address_(std::move(bind_address)),
168 addresses_(addresses),
169 current_address_index_(current_address_index),
170 network_quality_estimator_(network_quality_estimator),
171 network_(network) {
172 DCHECK(socket_);
173 if (socket_->IsValid())
174 socket_->SetDefaultOptionsForClient();
175 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
176 base::PowerMonitor::AddPowerSuspendObserver(this);
177 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
178 }
179
ReadCommon(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,bool read_if_ready)180 int TCPClientSocket::ReadCommon(IOBuffer* buf,
181 int buf_len,
182 CompletionOnceCallback callback,
183 bool read_if_ready) {
184 DCHECK(!callback.is_null());
185 DCHECK(read_callback_.is_null());
186
187 if (was_disconnected_on_suspend_)
188 return ERR_NETWORK_IO_SUSPENDED;
189
190 // |socket_| is owned by |this| and the callback won't be run once |socket_|
191 // is gone/closed. Therefore, it is safe to use base::Unretained() here.
192 CompletionOnceCallback complete_read_callback =
193 base::BindOnce(&TCPClientSocket::DidCompleteRead, base::Unretained(this));
194 int result =
195 read_if_ready
196 ? socket_->ReadIfReady(buf, buf_len,
197 std::move(complete_read_callback))
198 : socket_->Read(buf, buf_len, std::move(complete_read_callback));
199 if (result == ERR_IO_PENDING) {
200 read_callback_ = std::move(callback);
201 } else if (result > 0) {
202 was_ever_used_ = true;
203 total_received_bytes_ += result;
204 }
205
206 return result;
207 }
208
DoConnectLoop(int result)209 int TCPClientSocket::DoConnectLoop(int result) {
210 DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
211
212 int rv = result;
213 do {
214 ConnectState state = next_connect_state_;
215 next_connect_state_ = CONNECT_STATE_NONE;
216 switch (state) {
217 case CONNECT_STATE_CONNECT:
218 DCHECK_EQ(OK, rv);
219 rv = DoConnect();
220 break;
221 case CONNECT_STATE_CONNECT_COMPLETE:
222 rv = DoConnectComplete(rv);
223 break;
224 default:
225 NOTREACHED() << "bad state " << state;
226 rv = ERR_UNEXPECTED;
227 break;
228 }
229 } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
230
231 return rv;
232 }
233
DoConnect()234 int TCPClientSocket::DoConnect() {
235 DCHECK_GE(current_address_index_, 0);
236 DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
237
238 const IPEndPoint& endpoint = addresses_[current_address_index_];
239
240 if (previously_disconnected_) {
241 was_ever_used_ = false;
242 previously_disconnected_ = false;
243 }
244
245 next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
246
247 if (!socket_->IsValid()) {
248 int result = OpenSocket(endpoint.GetFamily());
249 if (result != OK)
250 return result;
251
252 if (bind_address_) {
253 result = socket_->Bind(*bind_address_);
254 if (result != OK) {
255 socket_->Close();
256 return result;
257 }
258 }
259 }
260
261 if (before_connect_callback_) {
262 int result = before_connect_callback_.Run();
263 DCHECK_NE(ERR_IO_PENDING, result);
264 if (result != net::OK)
265 return result;
266 }
267
268 // Notify |socket_performance_watcher_| only if the |socket_| is reused to
269 // connect to a different IP Address.
270 if (socket_->socket_performance_watcher() && current_address_index_ != 0)
271 socket_->socket_performance_watcher()->OnConnectionChanged();
272
273 start_connect_attempt_ = base::TimeTicks::Now();
274
275 // Start a timer to fail the connect attempt if it takes too long.
276 base::TimeDelta attempt_timeout = GetConnectAttemptTimeout();
277 if (!attempt_timeout.is_max()) {
278 DCHECK(!connect_attempt_timer_.IsRunning());
279 connect_attempt_timer_.Start(
280 FROM_HERE, attempt_timeout,
281 base::BindOnce(&TCPClientSocket::OnConnectAttemptTimeout,
282 base::Unretained(this)));
283 }
284
285 return ConnectInternal(endpoint);
286 }
287
DoConnectComplete(int result)288 int TCPClientSocket::DoConnectComplete(int result) {
289 if (start_connect_attempt_) {
290 EmitConnectAttemptHistograms(result);
291 start_connect_attempt_ = std::nullopt;
292 connect_attempt_timer_.Stop();
293 }
294
295 if (result == OK)
296 return OK; // Done!
297
298 // Don't try the next address if entering suspend mode.
299 if (result == ERR_NETWORK_IO_SUSPENDED)
300 return result;
301
302 // Close whatever partially connected socket we currently have.
303 DoDisconnect();
304
305 // Try to fall back to the next address in the list.
306 if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
307 next_connect_state_ = CONNECT_STATE_CONNECT;
308 ++current_address_index_;
309 return OK;
310 }
311
312 // Otherwise there is nothing to fall back to, so give up.
313 return result;
314 }
315
OnConnectAttemptTimeout()316 void TCPClientSocket::OnConnectAttemptTimeout() {
317 DidCompleteConnect(ERR_TIMED_OUT);
318 }
319
ConnectInternal(const IPEndPoint & endpoint)320 int TCPClientSocket::ConnectInternal(const IPEndPoint& endpoint) {
321 // |socket_| is owned by this class and the callback won't be run once
322 // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
323 return socket_->Connect(endpoint,
324 base::BindOnce(&TCPClientSocket::DidCompleteConnect,
325 base::Unretained(this)));
326 }
327
Disconnect()328 void TCPClientSocket::Disconnect() {
329 DoDisconnect();
330 current_address_index_ = -1;
331 bind_address_.reset();
332
333 // Cancel any pending callbacks. Not done in DoDisconnect() because that's
334 // called on connection failure, when the connect callback will need to be
335 // invoked.
336 was_disconnected_on_suspend_ = false;
337 connect_callback_.Reset();
338 read_callback_.Reset();
339 write_callback_.Reset();
340 }
341
DoDisconnect()342 void TCPClientSocket::DoDisconnect() {
343 if (start_connect_attempt_) {
344 EmitConnectAttemptHistograms(ERR_ABORTED);
345 start_connect_attempt_ = std::nullopt;
346 connect_attempt_timer_.Stop();
347 }
348
349 total_received_bytes_ = 0;
350
351 // If connecting or already connected, record that the socket has been
352 // disconnected.
353 previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0;
354 socket_->Close();
355
356 // Invalidate weak pointers, so if in the middle of a callback in OnSuspend,
357 // and something destroys this, no other callback is invoked.
358 weak_ptr_factory_.InvalidateWeakPtrs();
359 }
360
IsConnected() const361 bool TCPClientSocket::IsConnected() const {
362 return socket_->IsConnected();
363 }
364
IsConnectedAndIdle() const365 bool TCPClientSocket::IsConnectedAndIdle() const {
366 return socket_->IsConnectedAndIdle();
367 }
368
GetPeerAddress(IPEndPoint * address) const369 int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
370 return socket_->GetPeerAddress(address);
371 }
372
GetLocalAddress(IPEndPoint * address) const373 int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const {
374 DCHECK(address);
375
376 if (!socket_->IsValid()) {
377 if (bind_address_) {
378 *address = *bind_address_;
379 return OK;
380 }
381 return ERR_SOCKET_NOT_CONNECTED;
382 }
383
384 return socket_->GetLocalAddress(address);
385 }
386
NetLog() const387 const NetLogWithSource& TCPClientSocket::NetLog() const {
388 return socket_->net_log();
389 }
390
WasEverUsed() const391 bool TCPClientSocket::WasEverUsed() const {
392 return was_ever_used_;
393 }
394
GetNegotiatedProtocol() const395 NextProto TCPClientSocket::GetNegotiatedProtocol() const {
396 return kProtoUnknown;
397 }
398
GetSSLInfo(SSLInfo * ssl_info)399 bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
400 return false;
401 }
402
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)403 int TCPClientSocket::Read(IOBuffer* buf,
404 int buf_len,
405 CompletionOnceCallback callback) {
406 return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/false);
407 }
408
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)409 int TCPClientSocket::ReadIfReady(IOBuffer* buf,
410 int buf_len,
411 CompletionOnceCallback callback) {
412 return ReadCommon(buf, buf_len, std::move(callback), /*read_if_ready=*/true);
413 }
414
CancelReadIfReady()415 int TCPClientSocket::CancelReadIfReady() {
416 DCHECK(read_callback_);
417 read_callback_.Reset();
418 return socket_->CancelReadIfReady();
419 }
420
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)421 int TCPClientSocket::Write(
422 IOBuffer* buf,
423 int buf_len,
424 CompletionOnceCallback callback,
425 const NetworkTrafficAnnotationTag& traffic_annotation) {
426 DCHECK(!callback.is_null());
427 DCHECK(write_callback_.is_null());
428
429 if (was_disconnected_on_suspend_)
430 return ERR_NETWORK_IO_SUSPENDED;
431
432 // |socket_| is owned by this class and the callback won't be run once
433 // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
434 CompletionOnceCallback complete_write_callback = base::BindOnce(
435 &TCPClientSocket::DidCompleteWrite, base::Unretained(this));
436 int result = socket_->Write(buf, buf_len, std::move(complete_write_callback),
437 traffic_annotation);
438 if (result == ERR_IO_PENDING) {
439 write_callback_ = std::move(callback);
440 } else if (result > 0) {
441 was_ever_used_ = true;
442 }
443
444 return result;
445 }
446
SetReceiveBufferSize(int32_t size)447 int TCPClientSocket::SetReceiveBufferSize(int32_t size) {
448 return socket_->SetReceiveBufferSize(size);
449 }
450
SetSendBufferSize(int32_t size)451 int TCPClientSocket::SetSendBufferSize(int32_t size) {
452 return socket_->SetSendBufferSize(size);
453 }
454
SocketDescriptorForTesting() const455 SocketDescriptor TCPClientSocket::SocketDescriptorForTesting() const {
456 return socket_->SocketDescriptorForTesting();
457 }
458
GetTotalReceivedBytes() const459 int64_t TCPClientSocket::GetTotalReceivedBytes() const {
460 return total_received_bytes_;
461 }
462
ApplySocketTag(const SocketTag & tag)463 void TCPClientSocket::ApplySocketTag(const SocketTag& tag) {
464 socket_->ApplySocketTag(tag);
465 }
466
OnSuspend()467 void TCPClientSocket::OnSuspend() {
468 #if defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
469 // If the socket is connected, or connecting, act as if current and future
470 // operations on the socket fail with ERR_NETWORK_IO_SUSPENDED, until the
471 // socket is reconnected.
472
473 if (next_connect_state_ != CONNECT_STATE_NONE) {
474 socket_->Close();
475 DidCompleteConnect(ERR_NETWORK_IO_SUSPENDED);
476 return;
477 }
478
479 // Nothing to do. Use IsValid() rather than IsConnected() because it results
480 // in more testable code, as when calling OnSuspend mode on two sockets
481 // connected to each other will otherwise cause two sockets to behave
482 // differently from each other.
483 if (!socket_->IsValid())
484 return;
485
486 // Use Close() rather than Disconnect() / DoDisconnect() to avoid mutating
487 // state, which more closely matches normal read/write error behavior.
488 socket_->Close();
489
490 was_disconnected_on_suspend_ = true;
491
492 // Grab a weak pointer just in case calling read callback results in |this|
493 // being destroyed, or disconnected. In either case, should not run the write
494 // callback.
495 base::WeakPtr<TCPClientSocket> weak_this = weak_ptr_factory_.GetWeakPtr();
496
497 // Have to grab the write callback now, as it's theoretically possible for the
498 // read callback to reconnects the socket, that reconnection to complete
499 // synchronously, and then for it to start a new write. That also means this
500 // code can't use DidCompleteWrite().
501 CompletionOnceCallback write_callback = std::move(write_callback_);
502 if (read_callback_)
503 DidCompleteRead(ERR_NETWORK_IO_SUSPENDED);
504 if (weak_this && write_callback)
505 std::move(write_callback).Run(ERR_NETWORK_IO_SUSPENDED);
506 #endif // defined(TCP_CLIENT_SOCKET_OBSERVES_SUSPEND)
507 }
508
DidCompleteConnect(int result)509 void TCPClientSocket::DidCompleteConnect(int result) {
510 DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
511 DCHECK_NE(result, ERR_IO_PENDING);
512 DCHECK(!connect_callback_.is_null());
513
514 result = DoConnectLoop(result);
515 if (result != ERR_IO_PENDING) {
516 socket_->EndLoggingMultipleConnectAttempts(result);
517 std::move(connect_callback_).Run(result);
518 }
519 }
520
DidCompleteRead(int result)521 void TCPClientSocket::DidCompleteRead(int result) {
522 DCHECK(!read_callback_.is_null());
523
524 if (result > 0)
525 total_received_bytes_ += result;
526 DidCompleteReadWrite(std::move(read_callback_), result);
527 }
528
DidCompleteWrite(int result)529 void TCPClientSocket::DidCompleteWrite(int result) {
530 DCHECK(!write_callback_.is_null());
531
532 DidCompleteReadWrite(std::move(write_callback_), result);
533 }
534
DidCompleteReadWrite(CompletionOnceCallback callback,int result)535 void TCPClientSocket::DidCompleteReadWrite(CompletionOnceCallback callback,
536 int result) {
537 if (result > 0)
538 was_ever_used_ = true;
539 std::move(callback).Run(result);
540 }
541
OpenSocket(AddressFamily family)542 int TCPClientSocket::OpenSocket(AddressFamily family) {
543 DCHECK(!socket_->IsValid());
544
545 int result = socket_->Open(family);
546 if (result != OK)
547 return result;
548
549 if (network_ != handles::kInvalidNetworkHandle) {
550 result = socket_->BindToNetwork(network_);
551 if (result != OK) {
552 socket_->Close();
553 return result;
554 }
555 }
556
557 socket_->SetDefaultOptionsForClient();
558
559 return OK;
560 }
561
EmitConnectAttemptHistograms(int result)562 void TCPClientSocket::EmitConnectAttemptHistograms(int result) {
563 // This should only be called in response to completing a connect attempt.
564 DCHECK(start_connect_attempt_);
565
566 base::TimeDelta duration =
567 base::TimeTicks::Now() - start_connect_attempt_.value();
568
569 // Histogram the total time the connect attempt took, grouped by success and
570 // failure. Note that failures also include cases when the connect attempt
571 // was cancelled by the client before the handshake completed.
572 if (result == OK) {
573 UMA_HISTOGRAM_MEDIUM_TIMES("Net.TcpConnectAttempt.Latency.Success",
574 duration);
575 } else {
576 UMA_HISTOGRAM_MEDIUM_TIMES("Net.TcpConnectAttempt.Latency.Error", duration);
577 }
578 }
579
GetConnectAttemptTimeout()580 base::TimeDelta TCPClientSocket::GetConnectAttemptTimeout() {
581 if (!base::FeatureList::IsEnabled(features::kTimeoutTcpConnectAttempt))
582 return base::TimeDelta::Max();
583
584 std::optional<base::TimeDelta> transport_rtt = std::nullopt;
585 if (network_quality_estimator_)
586 transport_rtt = network_quality_estimator_->GetTransportRTT();
587
588 base::TimeDelta min_timeout = features::kTimeoutTcpConnectAttemptMin.Get();
589 base::TimeDelta max_timeout = features::kTimeoutTcpConnectAttemptMax.Get();
590
591 if (!transport_rtt)
592 return max_timeout;
593
594 base::TimeDelta adaptive_timeout =
595 transport_rtt.value() *
596 features::kTimeoutTcpConnectAttemptRTTMultiplier.Get();
597
598 if (adaptive_timeout <= min_timeout)
599 return min_timeout;
600
601 if (adaptive_timeout >= max_timeout)
602 return max_timeout;
603
604 return adaptive_timeout;
605 }
606
607 } // namespace net
608