1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/websocket_endpoint_lock_manager.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/functional/bind.h"
11 #include "base/location.h"
12 #include "base/logging.h"
13 #include "base/task/single_thread_task_runner.h"
14 #include "net/base/net_errors.h"
15
16 namespace net {
17
namespace {

// Default delay, in milliseconds, between UnlockEndpoint() and the moment the
// endpoint is actually released (or handed to the next waiter). The delay
// exists to mitigate DoS attacks.
// TODO(ricea): Replace this with randomised truncated exponential backoff.
// See crbug.com/377613.
constexpr int kUnlockDelayInMs = 10;

}  // namespace
26
~Waiter()27 WebSocketEndpointLockManager::Waiter::~Waiter() {
28 if (next()) {
29 DCHECK(previous());
30 RemoveFromList();
31 }
32 }
33
LockReleaser(WebSocketEndpointLockManager * websocket_endpoint_lock_manager,IPEndPoint endpoint)34 WebSocketEndpointLockManager::LockReleaser::LockReleaser(
35 WebSocketEndpointLockManager* websocket_endpoint_lock_manager,
36 IPEndPoint endpoint)
37 : websocket_endpoint_lock_manager_(websocket_endpoint_lock_manager),
38 endpoint_(endpoint) {
39 websocket_endpoint_lock_manager->RegisterLockReleaser(this, endpoint);
40 }
41
~LockReleaser()42 WebSocketEndpointLockManager::LockReleaser::~LockReleaser() {
43 if (websocket_endpoint_lock_manager_) {
44 websocket_endpoint_lock_manager_->UnlockEndpoint(endpoint_);
45 }
46 }
47
WebSocketEndpointLockManager()48 WebSocketEndpointLockManager::WebSocketEndpointLockManager()
49 : unlock_delay_(base::Milliseconds(kUnlockDelayInMs)) {}
50
~WebSocketEndpointLockManager()51 WebSocketEndpointLockManager::~WebSocketEndpointLockManager() {
52 DCHECK_EQ(lock_info_map_.size(), pending_unlock_count_);
53 }
54
LockEndpoint(const IPEndPoint & endpoint,Waiter * waiter)55 int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint& endpoint,
56 Waiter* waiter) {
57 LockInfoMap::value_type insert_value(endpoint, LockInfo());
58 std::pair<LockInfoMap::iterator, bool> rv =
59 lock_info_map_.insert(insert_value);
60 LockInfo& lock_info_in_map = rv.first->second;
61 if (rv.second) {
62 DVLOG(3) << "Locking endpoint " << endpoint.ToString();
63 lock_info_in_map.queue = std::make_unique<LockInfo::WaiterQueue>();
64 return OK;
65 }
66 DVLOG(3) << "Waiting for endpoint " << endpoint.ToString();
67 lock_info_in_map.queue->Append(waiter);
68 return ERR_IO_PENDING;
69 }
70
UnlockEndpoint(const IPEndPoint & endpoint)71 void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint& endpoint) {
72 auto lock_info_it = lock_info_map_.find(endpoint);
73 if (lock_info_it == lock_info_map_.end())
74 return;
75 LockReleaser* lock_releaser = lock_info_it->second.lock_releaser;
76 if (lock_releaser) {
77 lock_info_it->second.lock_releaser = nullptr;
78 lock_releaser->websocket_endpoint_lock_manager_ = nullptr;
79 }
80 UnlockEndpointAfterDelay(endpoint);
81 }
82
IsEmpty() const83 bool WebSocketEndpointLockManager::IsEmpty() const {
84 return lock_info_map_.empty();
85 }
86
SetUnlockDelayForTesting(base::TimeDelta new_delay)87 base::TimeDelta WebSocketEndpointLockManager::SetUnlockDelayForTesting(
88 base::TimeDelta new_delay) {
89 base::TimeDelta old_delay = unlock_delay_;
90 unlock_delay_ = new_delay;
91 return old_delay;
92 }
93
LockInfo()94 WebSocketEndpointLockManager::LockInfo::LockInfo() : lock_releaser(nullptr) {}
~LockInfo()95 WebSocketEndpointLockManager::LockInfo::~LockInfo() {
96 DCHECK(!lock_releaser);
97 }
98
LockInfo(const LockInfo & rhs)99 WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo& rhs)
100 : lock_releaser(rhs.lock_releaser) {
101 DCHECK(!rhs.queue);
102 }
103
RegisterLockReleaser(LockReleaser * lock_releaser,IPEndPoint endpoint)104 void WebSocketEndpointLockManager::RegisterLockReleaser(
105 LockReleaser* lock_releaser,
106 IPEndPoint endpoint) {
107 DCHECK(lock_releaser);
108 auto lock_info_it = lock_info_map_.find(endpoint);
109 CHECK(lock_info_it != lock_info_map_.end());
110 DCHECK(!lock_info_it->second.lock_releaser);
111 lock_info_it->second.lock_releaser = lock_releaser;
112 DVLOG(3) << "Registered (LockReleaser*)" << lock_releaser << " for "
113 << endpoint.ToString();
114 }
115
UnlockEndpointAfterDelay(const IPEndPoint & endpoint)116 void WebSocketEndpointLockManager::UnlockEndpointAfterDelay(
117 const IPEndPoint& endpoint) {
118 DVLOG(3) << "Delaying " << unlock_delay_.InMilliseconds()
119 << "ms before unlocking endpoint " << endpoint.ToString();
120 ++pending_unlock_count_;
121 base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
122 FROM_HERE,
123 base::BindOnce(&WebSocketEndpointLockManager::DelayedUnlockEndpoint,
124 weak_factory_.GetWeakPtr(), endpoint),
125 unlock_delay_);
126 }
127
DelayedUnlockEndpoint(const IPEndPoint & endpoint)128 void WebSocketEndpointLockManager::DelayedUnlockEndpoint(
129 const IPEndPoint& endpoint) {
130 auto lock_info_it = lock_info_map_.find(endpoint);
131 DCHECK_GT(pending_unlock_count_, 0U);
132 --pending_unlock_count_;
133 if (lock_info_it == lock_info_map_.end())
134 return;
135 DCHECK(!lock_info_it->second.lock_releaser);
136 LockInfo::WaiterQueue* queue = lock_info_it->second.queue.get();
137 DCHECK(queue);
138 if (queue->empty()) {
139 DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString();
140 lock_info_map_.erase(lock_info_it);
141 return;
142 }
143
144 DVLOG(3) << "Unlocking endpoint " << lock_info_it->first.ToString()
145 << " and activating next waiter";
146 Waiter* next_job = queue->head()->value();
147 next_job->RemoveFromList();
148 next_job->GotEndpointLock();
149 }
150
151 } // namespace net
152