xref: /aosp_15_r20/external/cronet/net/socket/websocket_transport_client_socket_pool.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
6 #define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
7 
#include <cstddef>
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>

#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/timer/timer.h"
#include "net/base/net_export.h"
#include "net/base/proxy_chain.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/client_socket_pool.h"
#include "net/socket/connect_job.h"
#include "net/socket/ssl_client_socket.h"
26 
27 namespace net {
28 
// Forward declarations: this header only uses these types through pointers
// and references, so their full definitions are not required here.
struct CommonConnectJobParams;
struct NetworkTrafficAnnotationTag;
31 
// Identifier for a ClientSocketHandle, used to scope the lifetime of
// references. ClientSocketHandleIDs are derived from ClientSocketHandle*
// values, are used for comparison only, and are never dereferenced. A
// std::uintptr_t is used to match the size of a pointer and to prevent
// dereferencing. Also, our tooling complains about dangling pointers if we
// pass around a raw pointer.
using ClientSocketHandleID = std::uintptr_t;
38 
39 class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool
40     : public ClientSocketPool {
41  public:
42   WebSocketTransportClientSocketPool(
43       int max_sockets,
44       int max_sockets_per_group,
45       const ProxyChain& proxy_chain,
46       const CommonConnectJobParams* common_connect_job_params);
47 
48   WebSocketTransportClientSocketPool(
49       const WebSocketTransportClientSocketPool&) = delete;
50   WebSocketTransportClientSocketPool& operator=(
51       const WebSocketTransportClientSocketPool&) = delete;
52 
53   ~WebSocketTransportClientSocketPool() override;
54 
55   // Allow another connection to be started to the IPEndPoint that this |handle|
56   // is connected to. Used when the WebSocket handshake completes successfully.
57   // This only works if the socket is connected, however the caller does not
58   // need to explicitly check for this. Instead, ensure that dead sockets are
59   // returned to ReleaseSocket() in a timely fashion.
60   static void UnlockEndpoint(
61       ClientSocketHandle* handle,
62       WebSocketEndpointLockManager* websocket_endpoint_lock_manager);
63 
64   // ClientSocketPool implementation.
65   int RequestSocket(
66       const GroupId& group_id,
67       scoped_refptr<SocketParams> params,
68       const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
69       RequestPriority priority,
70       const SocketTag& socket_tag,
71       RespectLimits respect_limits,
72       ClientSocketHandle* handle,
73       CompletionOnceCallback callback,
74       const ProxyAuthCallback& proxy_auth_callback,
75       const NetLogWithSource& net_log) override;
76   int RequestSockets(
77       const GroupId& group_id,
78       scoped_refptr<SocketParams> params,
79       const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
80       int num_sockets,
81       CompletionOnceCallback callback,
82       const NetLogWithSource& net_log) override;
83   void SetPriority(const GroupId& group_id,
84                    ClientSocketHandle* handle,
85                    RequestPriority priority) override;
86   void CancelRequest(const GroupId& group_id,
87                      ClientSocketHandle* handle,
88                      bool cancel_connect_job) override;
89   void ReleaseSocket(const GroupId& group_id,
90                      std::unique_ptr<StreamSocket> socket,
91                      int64_t generation) override;
92   void FlushWithError(int error, const char* net_log_reason_utf8) override;
93   void CloseIdleSockets(const char* net_log_reason_utf8) override;
94   void CloseIdleSocketsInGroup(const GroupId& group_id,
95                                const char* net_log_reason_utf8) override;
96   int IdleSocketCount() const override;
97   size_t IdleSocketCountInGroup(const GroupId& group_id) const override;
98   LoadState GetLoadState(const GroupId& group_id,
99                          const ClientSocketHandle* handle) const override;
100   base::Value GetInfoAsValue(const std::string& name,
101                              const std::string& type) const override;
102   bool HasActiveSocket(const GroupId& group_id) const override;
103 
104   // HigherLayeredPool implementation.
105   bool IsStalled() const override;
106   void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override;
107   void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override;
108 
109  private:
110   class ConnectJobDelegate : public ConnectJob::Delegate {
111    public:
112     ConnectJobDelegate(WebSocketTransportClientSocketPool* owner,
113                        CompletionOnceCallback callback,
114                        ClientSocketHandle* socket_handle,
115                        const NetLogWithSource& request_net_log);
116 
117     ConnectJobDelegate(const ConnectJobDelegate&) = delete;
118     ConnectJobDelegate& operator=(const ConnectJobDelegate&) = delete;
119 
120     ~ConnectJobDelegate() override;
121 
122     // ConnectJob::Delegate implementation
123     void OnConnectJobComplete(int result, ConnectJob* job) override;
124     void OnNeedsProxyAuth(const HttpResponseInfo& response,
125                           HttpAuthController* auth_controller,
126                           base::OnceClosure restart_with_auth_callback,
127                           ConnectJob* job) override;
128 
129     // Calls Connect() on |connect_job|, and takes ownership. Returns Connect's
130     // return value.
131     int Connect(std::unique_ptr<ConnectJob> connect_job);
132 
release_callback()133     CompletionOnceCallback release_callback() { return std::move(callback_); }
connect_job()134     ConnectJob* connect_job() { return connect_job_.get(); }
socket_handle()135     ClientSocketHandle* socket_handle() { return socket_handle_; }
136 
request_net_log()137     const NetLogWithSource& request_net_log() { return request_net_log_; }
138     const NetLogWithSource& connect_job_net_log();
139 
140    private:
141     raw_ptr<WebSocketTransportClientSocketPool> owner_;
142 
143     CompletionOnceCallback callback_;
144     std::unique_ptr<ConnectJob> connect_job_;
145     const raw_ptr<ClientSocketHandle> socket_handle_;
146     const NetLogWithSource request_net_log_;
147   };
148 
149   // Store the arguments from a call to RequestSocket() that has stalled so we
150   // can replay it when there are available socket slots.
151   struct StalledRequest {
152     StalledRequest(
153         const GroupId& group_id,
154         const scoped_refptr<SocketParams>& params,
155         const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
156         RequestPriority priority,
157         ClientSocketHandle* handle,
158         CompletionOnceCallback callback,
159         const ProxyAuthCallback& proxy_auth_callback,
160         const NetLogWithSource& net_log);
161     StalledRequest(StalledRequest&& other);
162     ~StalledRequest();
163 
164     const GroupId group_id;
165     const scoped_refptr<SocketParams> params;
166     const std::optional<NetworkTrafficAnnotationTag> proxy_annotation_tag;
167     const RequestPriority priority;
168     const raw_ptr<ClientSocketHandle> handle;
169     CompletionOnceCallback callback;
170     ProxyAuthCallback proxy_auth_callback;
171     const NetLogWithSource net_log;
172   };
173 
174   typedef std::map<const ClientSocketHandle*,
175                    std::unique_ptr<ConnectJobDelegate>>
176       PendingConnectsMap;
177   // This is a list so that we can remove requests from the middle, and also
178   // so that iterators are not invalidated unless the corresponding request is
179   // removed.
180   typedef std::list<StalledRequest> StalledRequestQueue;
181   typedef std::map<const ClientSocketHandle*, StalledRequestQueue::iterator>
182       StalledRequestMap;
183 
184   // Tries to hand out the socket connected by |job|. |result| must be (async)
185   // result of TransportConnectJob::Connect(). Returns true iff it has handed
186   // out a socket.
187   bool TryHandOutSocket(int result, ConnectJobDelegate* connect_job_delegate);
188   void OnConnectJobComplete(int result,
189                             ConnectJobDelegate* connect_job_delegate);
190   void InvokeUserCallbackLater(ClientSocketHandle* handle,
191                                CompletionOnceCallback callback,
192                                int rv);
193   void InvokeUserCallback(ClientSocketHandleID handle_id,
194                           base::WeakPtr<ClientSocketHandle> weak_handle,
195                           CompletionOnceCallback callback,
196                           int rv);
197   bool ReachedMaxSocketsLimit() const;
198   void HandOutSocket(std::unique_ptr<StreamSocket> socket,
199                      const LoadTimingInfo::ConnectTiming& connect_timing,
200                      ClientSocketHandle* handle,
201                      const NetLogWithSource& net_log);
202   void AddJob(ClientSocketHandle* handle,
203               std::unique_ptr<ConnectJobDelegate> delegate);
204   bool DeleteJob(ClientSocketHandle* handle);
205   const ConnectJob* LookupConnectJob(const ClientSocketHandle* handle) const;
206   void ActivateStalledRequest();
207   bool DeleteStalledRequest(ClientSocketHandle* handle);
208 
209   const ProxyChain proxy_chain_;
210   std::set<ClientSocketHandleID> pending_callbacks_;
211   PendingConnectsMap pending_connects_;
212   StalledRequestQueue stalled_request_queue_;
213   StalledRequestMap stalled_request_map_;
214   const int max_sockets_;
215   int handed_out_socket_count_ = 0;
216   bool flushing_ = false;
217 
218   base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_{this};
219 };
220 
221 }  // namespace net
222 
223 #endif  // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_
224