1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/traced_relay/socket_relay_handler.h"
18
19 #include <fcntl.h>
20 #include <sys/poll.h>
21 #include <algorithm>
22 #include <memory>
23 #include <mutex>
24 #include <thread>
25 #include <utility>
26
27 #include "perfetto/base/logging.h"
28 #include "perfetto/base/platform_handle.h"
29 #include "perfetto/ext/base/thread_checker.h"
30 #include "perfetto/ext/base/utils.h"
31 #include "perfetto/ext/base/watchdog.h"
32
33 namespace perfetto {
namespace {
// Use the default watchdog timeout for task runners.
constexpr int kWatchdogTimeoutMs = 30000;
// Timeout, in milliseconds, of the poll() call made by FdPoller::Poll().
constexpr int kPollTimeoutMs = 30000;
}  // namespace
40
41 FdPoller::Watcher::~Watcher() = default;
42
FdPoller(Watcher * watcher)43 FdPoller::FdPoller(Watcher* watcher) : watcher_(watcher) {
44 WatchForRead(notify_fd_.fd());
45
46 // This is done last in the ctor because WatchForRead() asserts using
47 // |thread_checker_|.
48 PERFETTO_DETACH_FROM_THREAD(thread_checker_);
49 }
50
Poll()51 void FdPoller::Poll() {
52 PERFETTO_DCHECK_THREAD(thread_checker_);
53
54 int num_fds = PERFETTO_EINTR(poll(
55 &poll_fds_[0], static_cast<nfds_t>(poll_fds_.size()), kPollTimeoutMs));
56 if (num_fds == -1 && base::IsAgain(errno))
57 return; // Poll again.
58 PERFETTO_DCHECK(num_fds <= static_cast<int>(poll_fds_.size()));
59
60 // Make a copy of |poll_fds_| so it's safe to watch and unwatch while
61 // notifying the watcher.
62 const auto poll_fds(poll_fds_);
63
64 for (const auto& event : poll_fds) {
65 if (!event.revents) // This event isn't active.
66 continue;
67
68 // Check whether the poller needs to break the polling loop for updates.
69 if (event.fd == notify_fd_.fd()) {
70 notify_fd_.Clear();
71 continue;
72 }
73
74 // Notify the callers on fd events.
75 if (event.revents & POLLOUT) {
76 watcher_->OnFdWritable(event.fd);
77 } else if (event.revents & POLLIN) {
78 watcher_->OnFdReadable(event.fd);
79 } else {
80 PERFETTO_DLOG("poll() returns events %d on fd %d", event.events,
81 event.fd);
82 } // Other events like POLLHUP or POLLERR are ignored.
83 }
84 }
85
Notify()86 void FdPoller::Notify() {
87 // Can be called from any thread.
88 notify_fd_.Notify();
89 }
90
FindPollEvent(base::PlatformHandle fd)91 std::vector<pollfd>::iterator FdPoller::FindPollEvent(base::PlatformHandle fd) {
92 PERFETTO_DCHECK_THREAD(thread_checker_);
93
94 return std::find_if(poll_fds_.begin(), poll_fds_.end(),
95 [fd](const pollfd& item) { return fd == item.fd; });
96 }
97
WatchFd(base::PlatformHandle fd,WatchEvents events)98 void FdPoller::WatchFd(base::PlatformHandle fd, WatchEvents events) {
99 auto it = FindPollEvent(fd);
100 if (it == poll_fds_.end()) {
101 poll_fds_.push_back({fd, events, 0});
102 } else {
103 it->events |= events;
104 }
105 }
106
UnwatchFd(base::PlatformHandle fd,WatchEvents events)107 void FdPoller::UnwatchFd(base::PlatformHandle fd, WatchEvents events) {
108 auto it = FindPollEvent(fd);
109 PERFETTO_CHECK(it != poll_fds_.end());
110 it->events &= ~events;
111 }
112
RemoveWatch(base::PlatformHandle fd)113 void FdPoller::RemoveWatch(base::PlatformHandle fd) {
114 auto it = FindPollEvent(fd);
115 PERFETTO_CHECK(it != poll_fds_.end());
116 poll_fds_.erase(it);
117 }
118
SocketRelayHandler()119 SocketRelayHandler::SocketRelayHandler() : fd_poller_(this) {
120 PERFETTO_DETACH_FROM_THREAD(io_thread_checker_);
121
122 io_thread_ = std::thread([this]() { this->Run(); });
123 }
124
~SocketRelayHandler()125 SocketRelayHandler::~SocketRelayHandler() {
126 RunOnIOThread([this]() { this->exited_ = true; });
127 io_thread_.join();
128 }
129
AddSocketPair(std::unique_ptr<SocketPair> socket_pair)130 void SocketRelayHandler::AddSocketPair(
131 std::unique_ptr<SocketPair> socket_pair) {
132 RunOnIOThread([this, socket_pair = std::move(socket_pair)]() mutable {
133 PERFETTO_DCHECK_THREAD(io_thread_checker_);
134
135 base::PlatformHandle fd1 = socket_pair->first.sock.fd();
136 base::PlatformHandle fd2 = socket_pair->second.sock.fd();
137 auto* ptr = socket_pair.get();
138 socket_pairs_.emplace_back(std::move(socket_pair));
139
140 fd_poller_.WatchForRead(fd1);
141 fd_poller_.WatchForRead(fd2);
142
143 socket_pairs_by_fd_[fd1] = ptr;
144 socket_pairs_by_fd_[fd2] = ptr;
145 });
146 }
147
Run()148 void SocketRelayHandler::Run() {
149 PERFETTO_DCHECK_THREAD(io_thread_checker_);
150
151 while (!exited_) {
152 fd_poller_.Poll();
153
154 auto handle = base::Watchdog::GetInstance()->CreateFatalTimer(
155 kWatchdogTimeoutMs, base::WatchdogCrashReason::kTaskRunnerHung);
156
157 std::deque<std::packaged_task<void()>> pending_tasks;
158 {
159 std::lock_guard<std::mutex> lock(mutex_);
160 pending_tasks = std::move(pending_tasks_);
161 }
162 while (!pending_tasks.empty()) {
163 auto task = std::move(pending_tasks.front());
164 pending_tasks.pop_front();
165 task();
166 }
167 }
168 }
169
OnFdReadable(base::PlatformHandle fd)170 void SocketRelayHandler::OnFdReadable(base::PlatformHandle fd) {
171 PERFETTO_DCHECK_THREAD(io_thread_checker_);
172
173 auto socket_pair = GetSocketPair(fd);
174 if (!socket_pair)
175 return; // Already removed.
176
177 auto [fd_sock, peer_sock] = *socket_pair;
178 // Buffer some bytes.
179 auto peer_fd = peer_sock.sock.fd();
180 while (fd_sock.available_bytes() > 0) {
181 auto rsize =
182 fd_sock.sock.Receive(fd_sock.buffer(), fd_sock.available_bytes());
183 if (rsize > 0) {
184 fd_sock.EnqueueData(static_cast<size_t>(rsize));
185 continue;
186 }
187
188 if (rsize == 0 || (rsize == -1 && !base::IsAgain(errno))) {
189 // TODO(chinglinyu): flush the remaining data to |peer_sock|.
190 RemoveSocketPair(fd_sock, peer_sock);
191 return;
192 }
193
194 // If there is any buffered data that needs to be sent to |peer_sock|, arm
195 // the write watcher.
196 if (fd_sock.data_size() > 0) {
197 fd_poller_.WatchForWrite(peer_fd);
198 }
199 return;
200 }
201 // We are not bufferable: need to turn off POLLIN to avoid spinning.
202 fd_poller_.UnwatchForRead(fd);
203 PERFETTO_DCHECK(fd_sock.data_size() > 0);
204 // Watching for POLLOUT will cause an OnFdWritable() event of
205 // |peer_sock|.
206 fd_poller_.WatchForWrite(peer_fd);
207 }
208
OnFdWritable(base::PlatformHandle fd)209 void SocketRelayHandler::OnFdWritable(base::PlatformHandle fd) {
210 PERFETTO_DCHECK_THREAD(io_thread_checker_);
211
212 auto socket_pair = GetSocketPair(fd);
213 if (!socket_pair)
214 return; // Already removed.
215
216 auto [fd_sock, peer_sock] = *socket_pair;
217 // |fd_sock| can be written to without blocking. Now we can transfer from the
218 // buffer in |peer_sock|.
219 while (peer_sock.data_size() > 0) {
220 auto wsize = fd_sock.sock.Send(peer_sock.data(), peer_sock.data_size());
221 if (wsize > 0) {
222 peer_sock.DequeueData(static_cast<size_t>(wsize));
223 continue;
224 }
225
226 if (wsize == -1 && !base::IsAgain(errno)) {
227 RemoveSocketPair(fd_sock, peer_sock);
228 }
229 // errno == EAGAIN and we still have data to send: continue watching for
230 // read.
231 return;
232 }
233
234 // We don't have buffered data to send. Disable watching for write.
235 fd_poller_.UnwatchForWrite(fd);
236 auto peer_fd = peer_sock.sock.fd();
237 if (peer_sock.available_bytes())
238 fd_poller_.WatchForRead(peer_fd);
239 }
240
241 std::optional<std::tuple<SocketWithBuffer&, SocketWithBuffer&>>
GetSocketPair(base::PlatformHandle fd)242 SocketRelayHandler::GetSocketPair(base::PlatformHandle fd) {
243 PERFETTO_DCHECK_THREAD(io_thread_checker_);
244
245 auto* socket_pair = socket_pairs_by_fd_.Find(fd);
246 if (!socket_pair)
247 return std::nullopt;
248
249 PERFETTO_DCHECK(fd == (*socket_pair)->first.sock.fd() ||
250 fd == (*socket_pair)->second.sock.fd());
251
252 if (fd == (*socket_pair)->first.sock.fd())
253 return std::tie((*socket_pair)->first, (*socket_pair)->second);
254
255 return std::tie((*socket_pair)->second, (*socket_pair)->first);
256 }
257
RemoveSocketPair(SocketWithBuffer & sock1,SocketWithBuffer & sock2)258 void SocketRelayHandler::RemoveSocketPair(SocketWithBuffer& sock1,
259 SocketWithBuffer& sock2) {
260 PERFETTO_DCHECK_THREAD(io_thread_checker_);
261
262 auto fd1 = sock1.sock.fd();
263 auto fd2 = sock2.sock.fd();
264 fd_poller_.RemoveWatch(fd1);
265 fd_poller_.RemoveWatch(fd2);
266
267 auto* ptr1 = socket_pairs_by_fd_.Find(fd1);
268 auto* ptr2 = socket_pairs_by_fd_.Find(fd2);
269 PERFETTO_DCHECK(ptr1 && ptr2);
270 PERFETTO_DCHECK(*ptr1 == *ptr2);
271
272 auto* socket_pair_ptr = *ptr1;
273
274 socket_pairs_by_fd_.Erase(fd1);
275 socket_pairs_by_fd_.Erase(fd2);
276
277 socket_pairs_.erase(
278 std::remove_if(
279 socket_pairs_.begin(), socket_pairs_.end(),
280 [socket_pair_ptr](const std::unique_ptr<SocketPair>& item) {
281 return item.get() == socket_pair_ptr;
282 }),
283 socket_pairs_.end());
284 }
285
286 } // namespace perfetto
287