1 // Copyright 2014 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 6 #define NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 7 8 #include <list> 9 #include <map> 10 #include <memory> 11 #include <optional> 12 #include <set> 13 #include <string> 14 #include <utility> 15 16 #include "base/memory/raw_ptr.h" 17 #include "base/memory/scoped_refptr.h" 18 #include "base/memory/weak_ptr.h" 19 #include "base/timer/timer.h" 20 #include "net/base/net_export.h" 21 #include "net/base/proxy_chain.h" 22 #include "net/log/net_log_with_source.h" 23 #include "net/socket/client_socket_pool.h" 24 #include "net/socket/connect_job.h" 25 #include "net/socket/ssl_client_socket.h" 26 27 namespace net { 28 29 struct CommonConnectJobParams; 30 struct NetworkTrafficAnnotationTag; 31 32 // Identifier for a ClientSocketHandle to scope the lifetime of references. 33 // ClientSocketHandleID are derived from ClientSocketHandle*, used in 34 // comparison only, and are never dereferenced. We use an std::uintptr_t here to 35 // match the size of a pointer, and to prevent dereferencing. Also, our 36 // tooling complains about dangling pointers if we pass around a raw ptr. 37 using ClientSocketHandleID = std::uintptr_t; 38 39 class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool 40 : public ClientSocketPool { 41 public: 42 WebSocketTransportClientSocketPool( 43 int max_sockets, 44 int max_sockets_per_group, 45 const ProxyChain& proxy_chain, 46 const CommonConnectJobParams* common_connect_job_params); 47 48 WebSocketTransportClientSocketPool( 49 const WebSocketTransportClientSocketPool&) = delete; 50 WebSocketTransportClientSocketPool& operator=( 51 const WebSocketTransportClientSocketPool&) = delete; 52 53 ~WebSocketTransportClientSocketPool() override; 54 55 // Allow another connection to be started to the IPEndPoint that this |handle| 56 // is connected to. Used when the WebSocket handshake completes successfully. 57 // This only works if the socket is connected, however the caller does not 58 // need to explicitly check for this. Instead, ensure that dead sockets are 59 // returned to ReleaseSocket() in a timely fashion. 60 static void UnlockEndpoint( 61 ClientSocketHandle* handle, 62 WebSocketEndpointLockManager* websocket_endpoint_lock_manager); 63 64 // ClientSocketPool implementation. 65 int RequestSocket( 66 const GroupId& group_id, 67 scoped_refptr<SocketParams> params, 68 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 69 RequestPriority priority, 70 const SocketTag& socket_tag, 71 RespectLimits respect_limits, 72 ClientSocketHandle* handle, 73 CompletionOnceCallback callback, 74 const ProxyAuthCallback& proxy_auth_callback, 75 const NetLogWithSource& net_log) override; 76 int RequestSockets( 77 const GroupId& group_id, 78 scoped_refptr<SocketParams> params, 79 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 80 int num_sockets, 81 CompletionOnceCallback callback, 82 const NetLogWithSource& net_log) override; 83 void SetPriority(const GroupId& group_id, 84 ClientSocketHandle* handle, 85 RequestPriority priority) override; 86 void CancelRequest(const GroupId& group_id, 87 ClientSocketHandle* handle, 88 bool cancel_connect_job) override; 89 void ReleaseSocket(const GroupId& group_id, 90 std::unique_ptr<StreamSocket> socket, 91 int64_t generation) override; 92 void FlushWithError(int error, const char* net_log_reason_utf8) override; 93 void CloseIdleSockets(const char* net_log_reason_utf8) override; 94 void CloseIdleSocketsInGroup(const GroupId& group_id, 95 const char* net_log_reason_utf8) override; 96 int IdleSocketCount() const override; 97 size_t IdleSocketCountInGroup(const GroupId& group_id) const override; 98 LoadState GetLoadState(const GroupId& group_id, 99 const ClientSocketHandle* handle) const override; 100 base::Value GetInfoAsValue(const std::string& name, 101 const std::string& type) const override; 102 bool HasActiveSocket(const GroupId& group_id) const override; 103 104 // HigherLayeredPool implementation. 105 bool IsStalled() const override; 106 void AddHigherLayeredPool(HigherLayeredPool* higher_pool) override; 107 void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) override; 108 109 private: 110 class ConnectJobDelegate : public ConnectJob::Delegate { 111 public: 112 ConnectJobDelegate(WebSocketTransportClientSocketPool* owner, 113 CompletionOnceCallback callback, 114 ClientSocketHandle* socket_handle, 115 const NetLogWithSource& request_net_log); 116 117 ConnectJobDelegate(const ConnectJobDelegate&) = delete; 118 ConnectJobDelegate& operator=(const ConnectJobDelegate&) = delete; 119 120 ~ConnectJobDelegate() override; 121 122 // ConnectJob::Delegate implementation 123 void OnConnectJobComplete(int result, ConnectJob* job) override; 124 void OnNeedsProxyAuth(const HttpResponseInfo& response, 125 HttpAuthController* auth_controller, 126 base::OnceClosure restart_with_auth_callback, 127 ConnectJob* job) override; 128 129 // Calls Connect() on |connect_job|, and takes ownership. Returns Connect's 130 // return value. 131 int Connect(std::unique_ptr<ConnectJob> connect_job); 132 release_callback()133 CompletionOnceCallback release_callback() { return std::move(callback_); } connect_job()134 ConnectJob* connect_job() { return connect_job_.get(); } socket_handle()135 ClientSocketHandle* socket_handle() { return socket_handle_; } 136 request_net_log()137 const NetLogWithSource& request_net_log() { return request_net_log_; } 138 const NetLogWithSource& connect_job_net_log(); 139 140 private: 141 raw_ptr<WebSocketTransportClientSocketPool> owner_; 142 143 CompletionOnceCallback callback_; 144 std::unique_ptr<ConnectJob> connect_job_; 145 const raw_ptr<ClientSocketHandle> socket_handle_; 146 const NetLogWithSource request_net_log_; 147 }; 148 149 // Store the arguments from a call to RequestSocket() that has stalled so we 150 // can replay it when there are available socket slots. 151 struct StalledRequest { 152 StalledRequest( 153 const GroupId& group_id, 154 const scoped_refptr<SocketParams>& params, 155 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 156 RequestPriority priority, 157 ClientSocketHandle* handle, 158 CompletionOnceCallback callback, 159 const ProxyAuthCallback& proxy_auth_callback, 160 const NetLogWithSource& net_log); 161 StalledRequest(StalledRequest&& other); 162 ~StalledRequest(); 163 164 const GroupId group_id; 165 const scoped_refptr<SocketParams> params; 166 const std::optional<NetworkTrafficAnnotationTag> proxy_annotation_tag; 167 const RequestPriority priority; 168 const raw_ptr<ClientSocketHandle> handle; 169 CompletionOnceCallback callback; 170 ProxyAuthCallback proxy_auth_callback; 171 const NetLogWithSource net_log; 172 }; 173 174 typedef std::map<const ClientSocketHandle*, 175 std::unique_ptr<ConnectJobDelegate>> 176 PendingConnectsMap; 177 // This is a list so that we can remove requests from the middle, and also 178 // so that iterators are not invalidated unless the corresponding request is 179 // removed. 180 typedef std::list<StalledRequest> StalledRequestQueue; 181 typedef std::map<const ClientSocketHandle*, StalledRequestQueue::iterator> 182 StalledRequestMap; 183 184 // Tries to hand out the socket connected by |job|. |result| must be (async) 185 // result of TransportConnectJob::Connect(). Returns true iff it has handed 186 // out a socket. 187 bool TryHandOutSocket(int result, ConnectJobDelegate* connect_job_delegate); 188 void OnConnectJobComplete(int result, 189 ConnectJobDelegate* connect_job_delegate); 190 void InvokeUserCallbackLater(ClientSocketHandle* handle, 191 CompletionOnceCallback callback, 192 int rv); 193 void InvokeUserCallback(ClientSocketHandleID handle_id, 194 base::WeakPtr<ClientSocketHandle> weak_handle, 195 CompletionOnceCallback callback, 196 int rv); 197 bool ReachedMaxSocketsLimit() const; 198 void HandOutSocket(std::unique_ptr<StreamSocket> socket, 199 const LoadTimingInfo::ConnectTiming& connect_timing, 200 ClientSocketHandle* handle, 201 const NetLogWithSource& net_log); 202 void AddJob(ClientSocketHandle* handle, 203 std::unique_ptr<ConnectJobDelegate> delegate); 204 bool DeleteJob(ClientSocketHandle* handle); 205 const ConnectJob* LookupConnectJob(const ClientSocketHandle* handle) const; 206 void ActivateStalledRequest(); 207 bool DeleteStalledRequest(ClientSocketHandle* handle); 208 209 const ProxyChain proxy_chain_; 210 std::set<ClientSocketHandleID> pending_callbacks_; 211 PendingConnectsMap pending_connects_; 212 StalledRequestQueue stalled_request_queue_; 213 StalledRequestMap stalled_request_map_; 214 const int max_sockets_; 215 int handed_out_socket_count_ = 0; 216 bool flushing_ = false; 217 218 base::WeakPtrFactory<WebSocketTransportClientSocketPool> weak_factory_{this}; 219 }; 220 221 } // namespace net 222 223 #endif // NET_SOCKET_WEBSOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ 224