1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/transport_client_socket_pool_test_util.h"
6
7 #include <stdint.h>
8 #include <string>
9 #include <utility>
10
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/location.h"
14 #include "base/memory/weak_ptr.h"
15 #include "base/notreached.h"
16 #include "base/run_loop.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "net/base/ip_address.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/load_timing_info.h"
21 #include "net/base/load_timing_info_test_util.h"
22 #include "net/log/net_log_source.h"
23 #include "net/log/net_log_source_type.h"
24 #include "net/log/net_log_with_source.h"
25 #include "net/socket/client_socket_handle.h"
26 #include "net/socket/datagram_client_socket.h"
27 #include "net/socket/ssl_client_socket.h"
28 #include "net/socket/transport_client_socket.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "testing/gtest/include/gtest/gtest.h"
31
32 namespace net {
33
34 namespace {
35
ParseIP(const std::string & ip)36 IPAddress ParseIP(const std::string& ip) {
37 IPAddress address;
38 CHECK(address.AssignFromIPLiteral(ip));
39 return address;
40 }
41
42 // A StreamSocket which connects synchronously and successfully.
43 class MockConnectClientSocket : public TransportClientSocket {
44 public:
MockConnectClientSocket(const AddressList & addrlist,net::NetLog * net_log)45 MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
46 : addrlist_(addrlist),
47 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
48
49 MockConnectClientSocket(const MockConnectClientSocket&) = delete;
50 MockConnectClientSocket& operator=(const MockConnectClientSocket&) = delete;
51
52 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)53 int Bind(const net::IPEndPoint& local_addr) override {
54 NOTREACHED();
55 return ERR_FAILED;
56 }
57 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)58 int Connect(CompletionOnceCallback callback) override {
59 connected_ = true;
60 return OK;
61 }
Disconnect()62 void Disconnect() override { connected_ = false; }
IsConnected() const63 bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const64 bool IsConnectedAndIdle() const override { return connected_; }
65
GetPeerAddress(IPEndPoint * address) const66 int GetPeerAddress(IPEndPoint* address) const override {
67 *address = addrlist_.front();
68 return OK;
69 }
GetLocalAddress(IPEndPoint * address) const70 int GetLocalAddress(IPEndPoint* address) const override {
71 if (!connected_)
72 return ERR_SOCKET_NOT_CONNECTED;
73 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
74 SetIPv4Address(address);
75 else
76 SetIPv6Address(address);
77 return OK;
78 }
NetLog() const79 const NetLogWithSource& NetLog() const override { return net_log_; }
80
WasEverUsed() const81 bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const82 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)83 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const84 int64_t GetTotalReceivedBytes() const override {
85 NOTIMPLEMENTED();
86 return 0;
87 }
ApplySocketTag(const SocketTag & tag)88 void ApplySocketTag(const SocketTag& tag) override {}
89
90 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)91 int Read(IOBuffer* buf,
92 int buf_len,
93 CompletionOnceCallback callback) override {
94 return ERR_FAILED;
95 }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)96 int Write(IOBuffer* buf,
97 int buf_len,
98 CompletionOnceCallback callback,
99 const NetworkTrafficAnnotationTag& traffic_annotation) override {
100 return ERR_FAILED;
101 }
SetReceiveBufferSize(int32_t size)102 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)103 int SetSendBufferSize(int32_t size) override { return OK; }
104
105 private:
106 bool connected_ = false;
107 const AddressList addrlist_;
108 NetLogWithSource net_log_;
109 };
110
111 class MockFailingClientSocket : public TransportClientSocket {
112 public:
MockFailingClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)113 MockFailingClientSocket(const AddressList& addrlist,
114 Error connect_error,
115 net::NetLog* net_log)
116 : addrlist_(addrlist),
117 connect_error_(connect_error),
118 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
119
120 MockFailingClientSocket(const MockFailingClientSocket&) = delete;
121 MockFailingClientSocket& operator=(const MockFailingClientSocket&) = delete;
122
123 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)124 int Bind(const net::IPEndPoint& local_addr) override {
125 NOTREACHED();
126 return ERR_FAILED;
127 }
128
129 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)130 int Connect(CompletionOnceCallback callback) override {
131 return connect_error_;
132 }
133
Disconnect()134 void Disconnect() override {}
135
IsConnected() const136 bool IsConnected() const override { return false; }
IsConnectedAndIdle() const137 bool IsConnectedAndIdle() const override { return false; }
GetPeerAddress(IPEndPoint * address) const138 int GetPeerAddress(IPEndPoint* address) const override {
139 return ERR_UNEXPECTED;
140 }
GetLocalAddress(IPEndPoint * address) const141 int GetLocalAddress(IPEndPoint* address) const override {
142 return ERR_UNEXPECTED;
143 }
NetLog() const144 const NetLogWithSource& NetLog() const override { return net_log_; }
145
WasEverUsed() const146 bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const147 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)148 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const149 int64_t GetTotalReceivedBytes() const override {
150 NOTIMPLEMENTED();
151 return 0;
152 }
ApplySocketTag(const SocketTag & tag)153 void ApplySocketTag(const SocketTag& tag) override {}
154
155 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)156 int Read(IOBuffer* buf,
157 int buf_len,
158 CompletionOnceCallback callback) override {
159 return ERR_FAILED;
160 }
161
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)162 int Write(IOBuffer* buf,
163 int buf_len,
164 CompletionOnceCallback callback,
165 const NetworkTrafficAnnotationTag& traffic_annotation) override {
166 return ERR_FAILED;
167 }
SetReceiveBufferSize(int32_t size)168 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)169 int SetSendBufferSize(int32_t size) override { return OK; }
170
171 private:
172 const AddressList addrlist_;
173 const Error connect_error_;
174 NetLogWithSource net_log_;
175 };
176
177 class MockTriggerableClientSocket : public TransportClientSocket {
178 public:
179 // |connect_error| indicates whether the socket should successfully complete
180 // or fail.
MockTriggerableClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)181 MockTriggerableClientSocket(const AddressList& addrlist,
182 Error connect_error,
183 net::NetLog* net_log)
184 : connect_error_(connect_error),
185 addrlist_(addrlist),
186 net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
187
188 MockTriggerableClientSocket(const MockTriggerableClientSocket&) = delete;
189 MockTriggerableClientSocket& operator=(const MockTriggerableClientSocket&) =
190 delete;
191
192 // Call this method to get a closure which will trigger the connect callback
193 // when called. The closure can be called even after the socket is deleted; it
194 // will safely do nothing.
GetConnectCallback()195 base::OnceClosure GetConnectCallback() {
196 return base::BindOnce(&MockTriggerableClientSocket::DoCallback,
197 weak_factory_.GetWeakPtr());
198 }
199
MakeMockPendingClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)200 static std::unique_ptr<TransportClientSocket> MakeMockPendingClientSocket(
201 const AddressList& addrlist,
202 Error connect_error,
203 net::NetLog* net_log) {
204 auto socket = std::make_unique<MockTriggerableClientSocket>(
205 addrlist, connect_error, net_log);
206 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
207 FROM_HERE, socket->GetConnectCallback());
208 return std::move(socket);
209 }
210
MakeMockDelayedClientSocket(const AddressList & addrlist,Error connect_error,const base::TimeDelta & delay,net::NetLog * net_log)211 static std::unique_ptr<TransportClientSocket> MakeMockDelayedClientSocket(
212 const AddressList& addrlist,
213 Error connect_error,
214 const base::TimeDelta& delay,
215 net::NetLog* net_log) {
216 auto socket = std::make_unique<MockTriggerableClientSocket>(
217 addrlist, connect_error, net_log);
218 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
219 FROM_HERE, socket->GetConnectCallback(), delay);
220 return std::move(socket);
221 }
222
MakeMockStalledClientSocket(const AddressList & addrlist,net::NetLog * net_log)223 static std::unique_ptr<TransportClientSocket> MakeMockStalledClientSocket(
224 const AddressList& addrlist,
225 net::NetLog* net_log) {
226 // We never post `GetConnectCallback()`, so the value of `connect_error`
227 // does not matter.
228 return std::make_unique<MockTriggerableClientSocket>(
229 addrlist, /*connect_error=*/OK, net_log);
230 }
231
232 // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)233 int Bind(const net::IPEndPoint& local_addr) override {
234 NOTREACHED();
235 return ERR_FAILED;
236 }
237
238 // StreamSocket implementation.
Connect(CompletionOnceCallback callback)239 int Connect(CompletionOnceCallback callback) override {
240 DCHECK(callback_.is_null());
241 callback_ = std::move(callback);
242 return ERR_IO_PENDING;
243 }
244
Disconnect()245 void Disconnect() override {}
246
IsConnected() const247 bool IsConnected() const override { return is_connected_; }
IsConnectedAndIdle() const248 bool IsConnectedAndIdle() const override { return is_connected_; }
GetPeerAddress(IPEndPoint * address) const249 int GetPeerAddress(IPEndPoint* address) const override {
250 *address = addrlist_.front();
251 return OK;
252 }
GetLocalAddress(IPEndPoint * address) const253 int GetLocalAddress(IPEndPoint* address) const override {
254 if (!is_connected_)
255 return ERR_SOCKET_NOT_CONNECTED;
256 if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
257 SetIPv4Address(address);
258 else
259 SetIPv6Address(address);
260 return OK;
261 }
NetLog() const262 const NetLogWithSource& NetLog() const override { return net_log_; }
263
WasEverUsed() const264 bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const265 NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)266 bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const267 int64_t GetTotalReceivedBytes() const override {
268 NOTIMPLEMENTED();
269 return 0;
270 }
ApplySocketTag(const SocketTag & tag)271 void ApplySocketTag(const SocketTag& tag) override {}
272
273 // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)274 int Read(IOBuffer* buf,
275 int buf_len,
276 CompletionOnceCallback callback) override {
277 return ERR_FAILED;
278 }
279
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)280 int Write(IOBuffer* buf,
281 int buf_len,
282 CompletionOnceCallback callback,
283 const NetworkTrafficAnnotationTag& traffic_annotation) override {
284 return ERR_FAILED;
285 }
SetReceiveBufferSize(int32_t size)286 int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)287 int SetSendBufferSize(int32_t size) override { return OK; }
288
289 private:
DoCallback()290 void DoCallback() {
291 is_connected_ = connect_error_ == OK;
292 std::move(callback_).Run(connect_error_);
293 }
294
295 Error connect_error_;
296 bool is_connected_ = false;
297 const AddressList addrlist_;
298 NetLogWithSource net_log_;
299 CompletionOnceCallback callback_;
300
301 base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_{this};
302 };
303
304 } // namespace
305
TestLoadTimingInfoConnectedReused(const ClientSocketHandle & handle)306 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
307 LoadTimingInfo load_timing_info;
308 // Only pass true in as |is_reused|, as in general, HttpStream types should
309 // have stricter concepts of reuse than socket pools.
310 EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
311
312 EXPECT_TRUE(load_timing_info.socket_reused);
313 EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
314
315 ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
316 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
317 }
318
TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle & handle)319 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
320 EXPECT_FALSE(handle.is_reused());
321
322 LoadTimingInfo load_timing_info;
323 EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
324
325 EXPECT_FALSE(load_timing_info.socket_reused);
326 EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
327
328 ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
329 CONNECT_TIMING_HAS_DNS_TIMES);
330 ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
331
332 TestLoadTimingInfoConnectedReused(handle);
333 }
334
SetIPv4Address(IPEndPoint * address)335 void SetIPv4Address(IPEndPoint* address) {
336 *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
337 }
338
SetIPv6Address(IPEndPoint * address)339 void SetIPv6Address(IPEndPoint* address) {
340 *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
341 }
342
Rule(Type type,std::optional<std::vector<IPEndPoint>> expected_addresses,Error connect_error)343 MockTransportClientSocketFactory::Rule::Rule(
344 Type type,
345 std::optional<std::vector<IPEndPoint>> expected_addresses,
346 Error connect_error)
347 : type(type),
348 expected_addresses(std::move(expected_addresses)),
349 connect_error(connect_error) {}
350
// Rule's special members are defaulted out of line. Rule is copyable:
// CreateTransportClientSocket() copies rules out of the configured span.
MockTransportClientSocketFactory::Rule::~Rule() = default;

MockTransportClientSocketFactory::Rule::Rule(const Rule&) = default;

MockTransportClientSocketFactory::Rule&
MockTransportClientSocketFactory::Rule::operator=(const Rule&) = default;
357
MockTransportClientSocketFactory(NetLog * net_log)358 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
359 NetLog* net_log)
360 : net_log_(net_log),
361 delay_(base::Milliseconds(ClientSocketPool::kMaxConnectRetryIntervalMs)) {
362 }
363
364 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;
365
366 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)367 MockTransportClientSocketFactory::CreateDatagramClientSocket(
368 DatagramSocket::BindType bind_type,
369 NetLog* net_log,
370 const NetLogSource& source) {
371 NOTREACHED();
372 return nullptr;
373 }
374
375 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher>,NetworkQualityEstimator *,NetLog *,const NetLogSource &)376 MockTransportClientSocketFactory::CreateTransportClientSocket(
377 const AddressList& addresses,
378 std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
379 NetworkQualityEstimator* /* network_quality_estimator */,
380 NetLog* /* net_log */,
381 const NetLogSource& /* source */) {
382 allocation_count_++;
383
384 Rule rule(client_socket_type_);
385 if (!rules_.empty()) {
386 rule = rules_.front();
387 rules_ = rules_.subspan(1);
388 }
389
390 if (rule.expected_addresses) {
391 EXPECT_EQ(addresses.endpoints(), *rule.expected_addresses);
392 }
393
394 switch (rule.type) {
395 case Type::kUnexpected:
396 ADD_FAILURE() << "Unexpectedly created socket to "
397 << addresses.endpoints().front();
398 return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
399 case Type::kSynchronous:
400 return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
401 case Type::kFailing:
402 return std::make_unique<MockFailingClientSocket>(
403 addresses, rule.connect_error, net_log_);
404 case Type::kPending:
405 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
406 addresses, OK, net_log_);
407 case Type::kPendingFailing:
408 return MockTriggerableClientSocket::MakeMockPendingClientSocket(
409 addresses, rule.connect_error, net_log_);
410 case Type::kDelayed:
411 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
412 addresses, OK, delay_, net_log_);
413 case Type::kDelayedFailing:
414 return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
415 addresses, rule.connect_error, delay_, net_log_);
416 case Type::kStalled:
417 return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses,
418 net_log_);
419 case Type::kTriggerable: {
420 auto rv = std::make_unique<MockTriggerableClientSocket>(addresses, OK,
421 net_log_);
422 triggerable_sockets_.push(rv->GetConnectCallback());
423 // run_loop_quit_closure_ behaves like a condition variable. It will
424 // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
425 // don't need to worry about atomicity because this code is
426 // single-threaded.
427 if (!run_loop_quit_closure_.is_null())
428 std::move(run_loop_quit_closure_).Run();
429 return std::move(rv);
430 }
431 default:
432 NOTREACHED();
433 return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
434 }
435 }
436
437 std::unique_ptr<SSLClientSocket>
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)438 MockTransportClientSocketFactory::CreateSSLClientSocket(
439 SSLClientContext* context,
440 std::unique_ptr<StreamSocket> stream_socket,
441 const HostPortPair& host_and_port,
442 const SSLConfig& ssl_config) {
443 NOTIMPLEMENTED();
444 return nullptr;
445 }
446
SetRules(base::span<const Rule> rules)447 void MockTransportClientSocketFactory::SetRules(base::span<const Rule> rules) {
448 DCHECK(rules_.empty());
449 client_socket_type_ = Type::kUnexpected;
450 rules_ = rules;
451 }
452
453 base::OnceClosure
WaitForTriggerableSocketCreation()454 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
455 while (triggerable_sockets_.empty()) {
456 base::RunLoop run_loop;
457 run_loop_quit_closure_ = run_loop.QuitClosure();
458 run_loop.Run();
459 run_loop_quit_closure_.Reset();
460 }
461 base::OnceClosure trigger = std::move(triggerable_sockets_.front());
462 triggerable_sockets_.pop();
463 return trigger;
464 }
465
466 } // namespace net
467