xref: /aosp_15_r20/external/perfetto/src/traced_relay/socket_relay_handler.cc (revision 6dbdd20afdafa5e3ca9b8809fa73465d530080dc)
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/traced_relay/socket_relay_handler.h"
18 
19 #include <fcntl.h>
20 #include <sys/poll.h>
21 #include <algorithm>
22 #include <memory>
23 #include <mutex>
24 #include <thread>
25 #include <utility>
26 
27 #include "perfetto/base/logging.h"
28 #include "perfetto/base/platform_handle.h"
29 #include "perfetto/ext/base/thread_checker.h"
30 #include "perfetto/ext/base/utils.h"
31 #include "perfetto/ext/base/watchdog.h"
32 
33 namespace perfetto {
namespace {
// Watchdog timeout armed around each batch of queued tasks run on the I/O
// thread. Uses the default watchdog timeout for task runners.
constexpr int kWatchdogTimeoutMs = 30000;
// Timeout of each poll() call made by FdPoller::Poll(). After it expires the
// I/O loop gets a chance to run pending tasks and re-check its exit flag.
constexpr int kPollTimeoutMs = 30000;
}  // namespace
40 
// Out-of-line defaulted destructor of the FdPoller::Watcher callback interface
// (presumably declared virtual in the header — confirm).
41 FdPoller::Watcher::~Watcher() = default;
42 
FdPoller(Watcher * watcher)43 FdPoller::FdPoller(Watcher* watcher) : watcher_(watcher) {
44   WatchForRead(notify_fd_.fd());
45 
46   // This is done last in the ctor because WatchForRead() asserts using
47   // |thread_checker_|.
48   PERFETTO_DETACH_FROM_THREAD(thread_checker_);
49 }
50 
Poll()51 void FdPoller::Poll() {
52   PERFETTO_DCHECK_THREAD(thread_checker_);
53 
54   int num_fds = PERFETTO_EINTR(poll(
55       &poll_fds_[0], static_cast<nfds_t>(poll_fds_.size()), kPollTimeoutMs));
56   if (num_fds == -1 && base::IsAgain(errno))
57     return;  // Poll again.
58   PERFETTO_DCHECK(num_fds <= static_cast<int>(poll_fds_.size()));
59 
60   // Make a copy of |poll_fds_| so it's safe to watch and unwatch while
61   // notifying the watcher.
62   const auto poll_fds(poll_fds_);
63 
64   for (const auto& event : poll_fds) {
65     if (!event.revents)  // This event isn't active.
66       continue;
67 
68     // Check whether the poller needs to break the polling loop for updates.
69     if (event.fd == notify_fd_.fd()) {
70       notify_fd_.Clear();
71       continue;
72     }
73 
74     // Notify the callers on fd events.
75     if (event.revents & POLLOUT) {
76       watcher_->OnFdWritable(event.fd);
77     } else if (event.revents & POLLIN) {
78       watcher_->OnFdReadable(event.fd);
79     } else {
80       PERFETTO_DLOG("poll() returns events %d on fd %d", event.events,
81                     event.fd);
82     }  // Other events like POLLHUP or POLLERR are ignored.
83   }
84 }
85 
Notify()86 void FdPoller::Notify() {
87   // Can be called from any thread.
88   notify_fd_.Notify();
89 }
90 
FindPollEvent(base::PlatformHandle fd)91 std::vector<pollfd>::iterator FdPoller::FindPollEvent(base::PlatformHandle fd) {
92   PERFETTO_DCHECK_THREAD(thread_checker_);
93 
94   return std::find_if(poll_fds_.begin(), poll_fds_.end(),
95                       [fd](const pollfd& item) { return fd == item.fd; });
96 }
97 
WatchFd(base::PlatformHandle fd,WatchEvents events)98 void FdPoller::WatchFd(base::PlatformHandle fd, WatchEvents events) {
99   auto it = FindPollEvent(fd);
100   if (it == poll_fds_.end()) {
101     poll_fds_.push_back({fd, events, 0});
102   } else {
103     it->events |= events;
104   }
105 }
106 
UnwatchFd(base::PlatformHandle fd,WatchEvents events)107 void FdPoller::UnwatchFd(base::PlatformHandle fd, WatchEvents events) {
108   auto it = FindPollEvent(fd);
109   PERFETTO_CHECK(it != poll_fds_.end());
110   it->events &= ~events;
111 }
112 
RemoveWatch(base::PlatformHandle fd)113 void FdPoller::RemoveWatch(base::PlatformHandle fd) {
114   auto it = FindPollEvent(fd);
115   PERFETTO_CHECK(it != poll_fds_.end());
116   poll_fds_.erase(it);
117 }
118 
SocketRelayHandler()119 SocketRelayHandler::SocketRelayHandler() : fd_poller_(this) {
120   PERFETTO_DETACH_FROM_THREAD(io_thread_checker_);
121 
122   io_thread_ = std::thread([this]() { this->Run(); });
123 }
124 
~SocketRelayHandler()125 SocketRelayHandler::~SocketRelayHandler() {
126   RunOnIOThread([this]() { this->exited_ = true; });
127   io_thread_.join();
128 }
129 
AddSocketPair(std::unique_ptr<SocketPair> socket_pair)130 void SocketRelayHandler::AddSocketPair(
131     std::unique_ptr<SocketPair> socket_pair) {
132   RunOnIOThread([this, socket_pair = std::move(socket_pair)]() mutable {
133     PERFETTO_DCHECK_THREAD(io_thread_checker_);
134 
135     base::PlatformHandle fd1 = socket_pair->first.sock.fd();
136     base::PlatformHandle fd2 = socket_pair->second.sock.fd();
137     auto* ptr = socket_pair.get();
138     socket_pairs_.emplace_back(std::move(socket_pair));
139 
140     fd_poller_.WatchForRead(fd1);
141     fd_poller_.WatchForRead(fd2);
142 
143     socket_pairs_by_fd_[fd1] = ptr;
144     socket_pairs_by_fd_[fd2] = ptr;
145   });
146 }
147 
Run()148 void SocketRelayHandler::Run() {
149   PERFETTO_DCHECK_THREAD(io_thread_checker_);
150 
151   while (!exited_) {
152     fd_poller_.Poll();
153 
154     auto handle = base::Watchdog::GetInstance()->CreateFatalTimer(
155         kWatchdogTimeoutMs, base::WatchdogCrashReason::kTaskRunnerHung);
156 
157     std::deque<std::packaged_task<void()>> pending_tasks;
158     {
159       std::lock_guard<std::mutex> lock(mutex_);
160       pending_tasks = std::move(pending_tasks_);
161     }
162     while (!pending_tasks.empty()) {
163       auto task = std::move(pending_tasks.front());
164       pending_tasks.pop_front();
165       task();
166     }
167   }
168 }
169 
OnFdReadable(base::PlatformHandle fd)170 void SocketRelayHandler::OnFdReadable(base::PlatformHandle fd) {
171   PERFETTO_DCHECK_THREAD(io_thread_checker_);
172 
173   auto socket_pair = GetSocketPair(fd);
174   if (!socket_pair)
175     return;  // Already removed.
176 
177   auto [fd_sock, peer_sock] = *socket_pair;
178   // Buffer some bytes.
179   auto peer_fd = peer_sock.sock.fd();
180   while (fd_sock.available_bytes() > 0) {
181     auto rsize =
182         fd_sock.sock.Receive(fd_sock.buffer(), fd_sock.available_bytes());
183     if (rsize > 0) {
184       fd_sock.EnqueueData(static_cast<size_t>(rsize));
185       continue;
186     }
187 
188     if (rsize == 0 || (rsize == -1 && !base::IsAgain(errno))) {
189       // TODO(chinglinyu): flush the remaining data to |peer_sock|.
190       RemoveSocketPair(fd_sock, peer_sock);
191       return;
192     }
193 
194     // If there is any buffered data that needs to be sent to |peer_sock|, arm
195     // the write watcher.
196     if (fd_sock.data_size() > 0) {
197       fd_poller_.WatchForWrite(peer_fd);
198     }
199     return;
200   }
201   // We are not bufferable: need to turn off POLLIN to avoid spinning.
202   fd_poller_.UnwatchForRead(fd);
203   PERFETTO_DCHECK(fd_sock.data_size() > 0);
204   // Watching for POLLOUT will cause an OnFdWritable() event of
205   // |peer_sock|.
206   fd_poller_.WatchForWrite(peer_fd);
207 }
208 
OnFdWritable(base::PlatformHandle fd)209 void SocketRelayHandler::OnFdWritable(base::PlatformHandle fd) {
210   PERFETTO_DCHECK_THREAD(io_thread_checker_);
211 
212   auto socket_pair = GetSocketPair(fd);
213   if (!socket_pair)
214     return;  // Already removed.
215 
216   auto [fd_sock, peer_sock] = *socket_pair;
217   // |fd_sock| can be written to without blocking. Now we can transfer from the
218   // buffer in |peer_sock|.
219   while (peer_sock.data_size() > 0) {
220     auto wsize = fd_sock.sock.Send(peer_sock.data(), peer_sock.data_size());
221     if (wsize > 0) {
222       peer_sock.DequeueData(static_cast<size_t>(wsize));
223       continue;
224     }
225 
226     if (wsize == -1 && !base::IsAgain(errno)) {
227       RemoveSocketPair(fd_sock, peer_sock);
228     }
229     // errno == EAGAIN and we still have data to send: continue watching for
230     // read.
231     return;
232   }
233 
234   // We don't have buffered data to send. Disable watching for write.
235   fd_poller_.UnwatchForWrite(fd);
236   auto peer_fd = peer_sock.sock.fd();
237   if (peer_sock.available_bytes())
238     fd_poller_.WatchForRead(peer_fd);
239 }
240 
241 std::optional<std::tuple<SocketWithBuffer&, SocketWithBuffer&>>
GetSocketPair(base::PlatformHandle fd)242 SocketRelayHandler::GetSocketPair(base::PlatformHandle fd) {
243   PERFETTO_DCHECK_THREAD(io_thread_checker_);
244 
245   auto* socket_pair = socket_pairs_by_fd_.Find(fd);
246   if (!socket_pair)
247     return std::nullopt;
248 
249   PERFETTO_DCHECK(fd == (*socket_pair)->first.sock.fd() ||
250                   fd == (*socket_pair)->second.sock.fd());
251 
252   if (fd == (*socket_pair)->first.sock.fd())
253     return std::tie((*socket_pair)->first, (*socket_pair)->second);
254 
255   return std::tie((*socket_pair)->second, (*socket_pair)->first);
256 }
257 
RemoveSocketPair(SocketWithBuffer & sock1,SocketWithBuffer & sock2)258 void SocketRelayHandler::RemoveSocketPair(SocketWithBuffer& sock1,
259                                           SocketWithBuffer& sock2) {
260   PERFETTO_DCHECK_THREAD(io_thread_checker_);
261 
262   auto fd1 = sock1.sock.fd();
263   auto fd2 = sock2.sock.fd();
264   fd_poller_.RemoveWatch(fd1);
265   fd_poller_.RemoveWatch(fd2);
266 
267   auto* ptr1 = socket_pairs_by_fd_.Find(fd1);
268   auto* ptr2 = socket_pairs_by_fd_.Find(fd2);
269   PERFETTO_DCHECK(ptr1 && ptr2);
270   PERFETTO_DCHECK(*ptr1 == *ptr2);
271 
272   auto* socket_pair_ptr = *ptr1;
273 
274   socket_pairs_by_fd_.Erase(fd1);
275   socket_pairs_by_fd_.Erase(fd2);
276 
277   socket_pairs_.erase(
278       std::remove_if(
279           socket_pairs_.begin(), socket_pairs_.end(),
280           [socket_pair_ptr](const std::unique_ptr<SocketPair>& item) {
281             return item.get() == socket_pair_ptr;
282           }),
283       socket_pairs_.end());
284 }
285 
286 }  // namespace perfetto
287