xref: /aosp_15_r20/external/cronet/net/socket/transport_client_socket_pool_test_util.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include "net/socket/transport_client_socket_pool_test_util.h"
6 
7 #include <stdint.h>
8 #include <string>
9 #include <utility>
10 
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/location.h"
14 #include "base/memory/weak_ptr.h"
15 #include "base/notreached.h"
16 #include "base/run_loop.h"
17 #include "base/task/single_thread_task_runner.h"
18 #include "net/base/ip_address.h"
19 #include "net/base/ip_endpoint.h"
20 #include "net/base/load_timing_info.h"
21 #include "net/base/load_timing_info_test_util.h"
22 #include "net/log/net_log_source.h"
23 #include "net/log/net_log_source_type.h"
24 #include "net/log/net_log_with_source.h"
25 #include "net/socket/client_socket_handle.h"
26 #include "net/socket/datagram_client_socket.h"
27 #include "net/socket/ssl_client_socket.h"
28 #include "net/socket/transport_client_socket.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "testing/gtest/include/gtest/gtest.h"
31 
32 namespace net {
33 
34 namespace {
35 
ParseIP(const std::string & ip)36 IPAddress ParseIP(const std::string& ip) {
37   IPAddress address;
38   CHECK(address.AssignFromIPLiteral(ip));
39   return address;
40 }
41 
42 // A StreamSocket which connects synchronously and successfully.
43 class MockConnectClientSocket : public TransportClientSocket {
44  public:
MockConnectClientSocket(const AddressList & addrlist,net::NetLog * net_log)45   MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
46       : addrlist_(addrlist),
47         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
48 
49   MockConnectClientSocket(const MockConnectClientSocket&) = delete;
50   MockConnectClientSocket& operator=(const MockConnectClientSocket&) = delete;
51 
52   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)53   int Bind(const net::IPEndPoint& local_addr) override {
54     NOTREACHED();
55     return ERR_FAILED;
56   }
57   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)58   int Connect(CompletionOnceCallback callback) override {
59     connected_ = true;
60     return OK;
61   }
Disconnect()62   void Disconnect() override { connected_ = false; }
IsConnected() const63   bool IsConnected() const override { return connected_; }
IsConnectedAndIdle() const64   bool IsConnectedAndIdle() const override { return connected_; }
65 
GetPeerAddress(IPEndPoint * address) const66   int GetPeerAddress(IPEndPoint* address) const override {
67     *address = addrlist_.front();
68     return OK;
69   }
GetLocalAddress(IPEndPoint * address) const70   int GetLocalAddress(IPEndPoint* address) const override {
71     if (!connected_)
72       return ERR_SOCKET_NOT_CONNECTED;
73     if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
74       SetIPv4Address(address);
75     else
76       SetIPv6Address(address);
77     return OK;
78   }
NetLog() const79   const NetLogWithSource& NetLog() const override { return net_log_; }
80 
WasEverUsed() const81   bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const82   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)83   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const84   int64_t GetTotalReceivedBytes() const override {
85     NOTIMPLEMENTED();
86     return 0;
87   }
ApplySocketTag(const SocketTag & tag)88   void ApplySocketTag(const SocketTag& tag) override {}
89 
90   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)91   int Read(IOBuffer* buf,
92            int buf_len,
93            CompletionOnceCallback callback) override {
94     return ERR_FAILED;
95   }
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)96   int Write(IOBuffer* buf,
97             int buf_len,
98             CompletionOnceCallback callback,
99             const NetworkTrafficAnnotationTag& traffic_annotation) override {
100     return ERR_FAILED;
101   }
SetReceiveBufferSize(int32_t size)102   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)103   int SetSendBufferSize(int32_t size) override { return OK; }
104 
105  private:
106   bool connected_ = false;
107   const AddressList addrlist_;
108   NetLogWithSource net_log_;
109 };
110 
111 class MockFailingClientSocket : public TransportClientSocket {
112  public:
MockFailingClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)113   MockFailingClientSocket(const AddressList& addrlist,
114                           Error connect_error,
115                           net::NetLog* net_log)
116       : addrlist_(addrlist),
117         connect_error_(connect_error),
118         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
119 
120   MockFailingClientSocket(const MockFailingClientSocket&) = delete;
121   MockFailingClientSocket& operator=(const MockFailingClientSocket&) = delete;
122 
123   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)124   int Bind(const net::IPEndPoint& local_addr) override {
125     NOTREACHED();
126     return ERR_FAILED;
127   }
128 
129   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)130   int Connect(CompletionOnceCallback callback) override {
131     return connect_error_;
132   }
133 
Disconnect()134   void Disconnect() override {}
135 
IsConnected() const136   bool IsConnected() const override { return false; }
IsConnectedAndIdle() const137   bool IsConnectedAndIdle() const override { return false; }
GetPeerAddress(IPEndPoint * address) const138   int GetPeerAddress(IPEndPoint* address) const override {
139     return ERR_UNEXPECTED;
140   }
GetLocalAddress(IPEndPoint * address) const141   int GetLocalAddress(IPEndPoint* address) const override {
142     return ERR_UNEXPECTED;
143   }
NetLog() const144   const NetLogWithSource& NetLog() const override { return net_log_; }
145 
WasEverUsed() const146   bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const147   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)148   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const149   int64_t GetTotalReceivedBytes() const override {
150     NOTIMPLEMENTED();
151     return 0;
152   }
ApplySocketTag(const SocketTag & tag)153   void ApplySocketTag(const SocketTag& tag) override {}
154 
155   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)156   int Read(IOBuffer* buf,
157            int buf_len,
158            CompletionOnceCallback callback) override {
159     return ERR_FAILED;
160   }
161 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)162   int Write(IOBuffer* buf,
163             int buf_len,
164             CompletionOnceCallback callback,
165             const NetworkTrafficAnnotationTag& traffic_annotation) override {
166     return ERR_FAILED;
167   }
SetReceiveBufferSize(int32_t size)168   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)169   int SetSendBufferSize(int32_t size) override { return OK; }
170 
171  private:
172   const AddressList addrlist_;
173   const Error connect_error_;
174   NetLogWithSource net_log_;
175 };
176 
177 class MockTriggerableClientSocket : public TransportClientSocket {
178  public:
179   // |connect_error| indicates whether the socket should successfully complete
180   // or fail.
MockTriggerableClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)181   MockTriggerableClientSocket(const AddressList& addrlist,
182                               Error connect_error,
183                               net::NetLog* net_log)
184       : connect_error_(connect_error),
185         addrlist_(addrlist),
186         net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
187 
188   MockTriggerableClientSocket(const MockTriggerableClientSocket&) = delete;
189   MockTriggerableClientSocket& operator=(const MockTriggerableClientSocket&) =
190       delete;
191 
192   // Call this method to get a closure which will trigger the connect callback
193   // when called. The closure can be called even after the socket is deleted; it
194   // will safely do nothing.
GetConnectCallback()195   base::OnceClosure GetConnectCallback() {
196     return base::BindOnce(&MockTriggerableClientSocket::DoCallback,
197                           weak_factory_.GetWeakPtr());
198   }
199 
MakeMockPendingClientSocket(const AddressList & addrlist,Error connect_error,net::NetLog * net_log)200   static std::unique_ptr<TransportClientSocket> MakeMockPendingClientSocket(
201       const AddressList& addrlist,
202       Error connect_error,
203       net::NetLog* net_log) {
204     auto socket = std::make_unique<MockTriggerableClientSocket>(
205         addrlist, connect_error, net_log);
206     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
207         FROM_HERE, socket->GetConnectCallback());
208     return std::move(socket);
209   }
210 
MakeMockDelayedClientSocket(const AddressList & addrlist,Error connect_error,const base::TimeDelta & delay,net::NetLog * net_log)211   static std::unique_ptr<TransportClientSocket> MakeMockDelayedClientSocket(
212       const AddressList& addrlist,
213       Error connect_error,
214       const base::TimeDelta& delay,
215       net::NetLog* net_log) {
216     auto socket = std::make_unique<MockTriggerableClientSocket>(
217         addrlist, connect_error, net_log);
218     base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
219         FROM_HERE, socket->GetConnectCallback(), delay);
220     return std::move(socket);
221   }
222 
MakeMockStalledClientSocket(const AddressList & addrlist,net::NetLog * net_log)223   static std::unique_ptr<TransportClientSocket> MakeMockStalledClientSocket(
224       const AddressList& addrlist,
225       net::NetLog* net_log) {
226     // We never post `GetConnectCallback()`, so the value of `connect_error`
227     // does not matter.
228     return std::make_unique<MockTriggerableClientSocket>(
229         addrlist, /*connect_error=*/OK, net_log);
230   }
231 
232   // TransportClientSocket implementation.
Bind(const net::IPEndPoint & local_addr)233   int Bind(const net::IPEndPoint& local_addr) override {
234     NOTREACHED();
235     return ERR_FAILED;
236   }
237 
238   // StreamSocket implementation.
Connect(CompletionOnceCallback callback)239   int Connect(CompletionOnceCallback callback) override {
240     DCHECK(callback_.is_null());
241     callback_ = std::move(callback);
242     return ERR_IO_PENDING;
243   }
244 
Disconnect()245   void Disconnect() override {}
246 
IsConnected() const247   bool IsConnected() const override { return is_connected_; }
IsConnectedAndIdle() const248   bool IsConnectedAndIdle() const override { return is_connected_; }
GetPeerAddress(IPEndPoint * address) const249   int GetPeerAddress(IPEndPoint* address) const override {
250     *address = addrlist_.front();
251     return OK;
252   }
GetLocalAddress(IPEndPoint * address) const253   int GetLocalAddress(IPEndPoint* address) const override {
254     if (!is_connected_)
255       return ERR_SOCKET_NOT_CONNECTED;
256     if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
257       SetIPv4Address(address);
258     else
259       SetIPv6Address(address);
260     return OK;
261   }
NetLog() const262   const NetLogWithSource& NetLog() const override { return net_log_; }
263 
WasEverUsed() const264   bool WasEverUsed() const override { return false; }
GetNegotiatedProtocol() const265   NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
GetSSLInfo(SSLInfo * ssl_info)266   bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
GetTotalReceivedBytes() const267   int64_t GetTotalReceivedBytes() const override {
268     NOTIMPLEMENTED();
269     return 0;
270   }
ApplySocketTag(const SocketTag & tag)271   void ApplySocketTag(const SocketTag& tag) override {}
272 
273   // Socket implementation.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)274   int Read(IOBuffer* buf,
275            int buf_len,
276            CompletionOnceCallback callback) override {
277     return ERR_FAILED;
278   }
279 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)280   int Write(IOBuffer* buf,
281             int buf_len,
282             CompletionOnceCallback callback,
283             const NetworkTrafficAnnotationTag& traffic_annotation) override {
284     return ERR_FAILED;
285   }
SetReceiveBufferSize(int32_t size)286   int SetReceiveBufferSize(int32_t size) override { return OK; }
SetSendBufferSize(int32_t size)287   int SetSendBufferSize(int32_t size) override { return OK; }
288 
289  private:
DoCallback()290   void DoCallback() {
291     is_connected_ = connect_error_ == OK;
292     std::move(callback_).Run(connect_error_);
293   }
294 
295   Error connect_error_;
296   bool is_connected_ = false;
297   const AddressList addrlist_;
298   NetLogWithSource net_log_;
299   CompletionOnceCallback callback_;
300 
301   base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_{this};
302 };
303 
304 }  // namespace
305 
TestLoadTimingInfoConnectedReused(const ClientSocketHandle & handle)306 void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
307   LoadTimingInfo load_timing_info;
308   // Only pass true in as |is_reused|, as in general, HttpStream types should
309   // have stricter concepts of reuse than socket pools.
310   EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
311 
312   EXPECT_TRUE(load_timing_info.socket_reused);
313   EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
314 
315   ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
316   ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
317 }
318 
TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle & handle)319 void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
320   EXPECT_FALSE(handle.is_reused());
321 
322   LoadTimingInfo load_timing_info;
323   EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
324 
325   EXPECT_FALSE(load_timing_info.socket_reused);
326   EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
327 
328   ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
329                               CONNECT_TIMING_HAS_DNS_TIMES);
330   ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
331 
332   TestLoadTimingInfoConnectedReused(handle);
333 }
334 
SetIPv4Address(IPEndPoint * address)335 void SetIPv4Address(IPEndPoint* address) {
336   *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
337 }
338 
SetIPv6Address(IPEndPoint * address)339 void SetIPv6Address(IPEndPoint* address) {
340   *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
341 }
342 
Rule(Type type,std::optional<std::vector<IPEndPoint>> expected_addresses,Error connect_error)343 MockTransportClientSocketFactory::Rule::Rule(
344     Type type,
345     std::optional<std::vector<IPEndPoint>> expected_addresses,
346     Error connect_error)
347     : type(type),
348       expected_addresses(std::move(expected_addresses)),
349       connect_error(connect_error) {}
350 
351 MockTransportClientSocketFactory::Rule::~Rule() = default;
352 
353 MockTransportClientSocketFactory::Rule::Rule(const Rule&) = default;
354 
355 MockTransportClientSocketFactory::Rule&
356 MockTransportClientSocketFactory::Rule::operator=(const Rule&) = default;
357 
MockTransportClientSocketFactory(NetLog * net_log)358 MockTransportClientSocketFactory::MockTransportClientSocketFactory(
359     NetLog* net_log)
360     : net_log_(net_log),
361       delay_(base::Milliseconds(ClientSocketPool::kMaxConnectRetryIntervalMs)) {
362 }
363 
364 MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;
365 
366 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)367 MockTransportClientSocketFactory::CreateDatagramClientSocket(
368     DatagramSocket::BindType bind_type,
369     NetLog* net_log,
370     const NetLogSource& source) {
371   NOTREACHED();
372   return nullptr;
373 }
374 
375 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher>,NetworkQualityEstimator *,NetLog *,const NetLogSource &)376 MockTransportClientSocketFactory::CreateTransportClientSocket(
377     const AddressList& addresses,
378     std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
379     NetworkQualityEstimator* /* network_quality_estimator */,
380     NetLog* /* net_log */,
381     const NetLogSource& /* source */) {
382   allocation_count_++;
383 
384   Rule rule(client_socket_type_);
385   if (!rules_.empty()) {
386     rule = rules_.front();
387     rules_ = rules_.subspan(1);
388   }
389 
390   if (rule.expected_addresses) {
391     EXPECT_EQ(addresses.endpoints(), *rule.expected_addresses);
392   }
393 
394   switch (rule.type) {
395     case Type::kUnexpected:
396       ADD_FAILURE() << "Unexpectedly created socket to "
397                     << addresses.endpoints().front();
398       return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
399     case Type::kSynchronous:
400       return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
401     case Type::kFailing:
402       return std::make_unique<MockFailingClientSocket>(
403           addresses, rule.connect_error, net_log_);
404     case Type::kPending:
405       return MockTriggerableClientSocket::MakeMockPendingClientSocket(
406           addresses, OK, net_log_);
407     case Type::kPendingFailing:
408       return MockTriggerableClientSocket::MakeMockPendingClientSocket(
409           addresses, rule.connect_error, net_log_);
410     case Type::kDelayed:
411       return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
412           addresses, OK, delay_, net_log_);
413     case Type::kDelayedFailing:
414       return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
415           addresses, rule.connect_error, delay_, net_log_);
416     case Type::kStalled:
417       return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses,
418                                                                       net_log_);
419     case Type::kTriggerable: {
420       auto rv = std::make_unique<MockTriggerableClientSocket>(addresses, OK,
421                                                               net_log_);
422       triggerable_sockets_.push(rv->GetConnectCallback());
423       // run_loop_quit_closure_ behaves like a condition variable. It will
424       // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
425       // don't need to worry about atomicity because this code is
426       // single-threaded.
427       if (!run_loop_quit_closure_.is_null())
428         std::move(run_loop_quit_closure_).Run();
429       return std::move(rv);
430     }
431     default:
432       NOTREACHED();
433       return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
434   }
435 }
436 
437 std::unique_ptr<SSLClientSocket>
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)438 MockTransportClientSocketFactory::CreateSSLClientSocket(
439     SSLClientContext* context,
440     std::unique_ptr<StreamSocket> stream_socket,
441     const HostPortPair& host_and_port,
442     const SSLConfig& ssl_config) {
443   NOTIMPLEMENTED();
444   return nullptr;
445 }
446 
SetRules(base::span<const Rule> rules)447 void MockTransportClientSocketFactory::SetRules(base::span<const Rule> rules) {
448   DCHECK(rules_.empty());
449   client_socket_type_ = Type::kUnexpected;
450   rules_ = rules;
451 }
452 
453 base::OnceClosure
WaitForTriggerableSocketCreation()454 MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
455   while (triggerable_sockets_.empty()) {
456     base::RunLoop run_loop;
457     run_loop_quit_closure_ = run_loop.QuitClosure();
458     run_loop.Run();
459     run_loop_quit_closure_.Reset();
460   }
461   base::OnceClosure trigger = std::move(triggerable_sockets_.front());
462   triggerable_sockets_.pop();
463   return trigger;
464 }
465 
466 }  // namespace net
467