xref: /aosp_15_r20/system/chre/host/common/socket_client.cc (revision 84e339476a462649f82315436d70fd732297a399)
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_client.h"
18 
19 #include <inttypes.h>
20 
21 #include <string.h>
22 #include <unistd.h>
23 
24 #include <chrono>
25 
26 #include <cutils/sockets.h>
27 #include <sys/epoll.h>
28 #include <sys/socket.h>
29 #include <utils/RefBase.h>
30 #include <utils/StrongPointer.h>
31 
32 #include "chre_host/log.h"
33 
34 namespace android {
35 namespace chre {
36 
SocketClient()37 SocketClient::SocketClient() {
38   std::atomic_init(&mSockFd, INVALID_SOCKET);
39 }
40 
~SocketClient()41 SocketClient::~SocketClient() {
42   disconnect();
43 }
44 
connect(const char * socketName,const sp<ICallbacks> & callbacks)45 bool SocketClient::connect(const char *socketName,
46                            const sp<ICallbacks> &callbacks) {
47   return doConnect(socketName, callbacks, false /* connectInBackground */);
48 }
49 
connectInBackground(const char * socketName,const sp<ICallbacks> & callbacks)50 bool SocketClient::connectInBackground(const char *socketName,
51                                        const sp<ICallbacks> &callbacks) {
52   return doConnect(socketName, callbacks, true /* connectInBackground */);
53 }
54 
disconnect()55 void SocketClient::disconnect() {
56   if (inReceiveThread()) {
57     LOGE("disconnect() can't be called from a receive thread callback");
58   } else if (receiveThreadRunning()) {
59     // Inform the RX thread that we're requesting a shutdown, breaking it out of
60     // the retry wait if it's currently blocked there
61     {
62       std::lock_guard<std::mutex> lock(mShutdownMutex);
63       mGracefulShutdown = true;
64     }
65     mShutdownCond.notify_all();
66 
67     // Invalidate the socket (will kick the RX thread out of recv if it's
68     // currently blocked there)
69     if (mSockFd != INVALID_SOCKET && shutdown(mSockFd, SHUT_RDWR) != 0) {
70       LOG_ERROR("Couldn't shut down socket", errno);
71     }
72 
73     if (mRxThread.joinable()) {
74       LOGD("Waiting for RX thread to exit");
75       mRxThread.join();
76     }
77   }
78 }
79 
isConnected() const80 bool SocketClient::isConnected() const {
81   return (mSockFd != INVALID_SOCKET);
82 }
83 
sendMessage(const void * data,size_t length)84 bool SocketClient::sendMessage(const void *data, size_t length) {
85   bool success = false;
86 
87   if (mSockFd == INVALID_SOCKET) {
88     LOGW("Tried sending a message, but don't have a valid socket handle");
89   } else {
90     ssize_t bytesSent = send(mSockFd, data, length, 0);
91     if (bytesSent < 0) {
92       LOGE("Failed to send %zu bytes of data: %s", length, strerror(errno));
93     } else if (bytesSent == 0) {
94       LOGW("Failed to send data; remote side disconnected");
95     } else if (static_cast<size_t>(bytesSent) != length) {
96       LOGW("Truncated packet, tried sending %zu bytes, only %zd went through",
97            length, bytesSent);
98     } else {
99       success = true;
100     }
101   }
102 
103   return success;
104 }
105 
doConnect(const char * socketName,const sp<ICallbacks> & callbacks,bool connectInBackground)106 bool SocketClient::doConnect(const char *socketName,
107                              const sp<ICallbacks> &callbacks,
108                              bool connectInBackground) {
109   bool success = false;
110   if (inReceiveThread()) {
111     LOGE("Can't attempt to connect from a receive thread callback");
112   } else {
113     if (receiveThreadRunning()) {
114       LOGW("Re-connecting socket with implicit disconnect");
115       disconnect();
116     }
117 
118     size_t socketNameLen =
119         strlcpy(mSocketName, socketName, sizeof(mSocketName));
120     if (socketNameLen >= sizeof(mSocketName)) {
121       LOGE("Socket name length parameter is too long (%zu, max %zu)",
122            socketNameLen, sizeof(mSocketName));
123     } else if (callbacks == nullptr) {
124       LOGE("Callbacks parameter must be provided");
125     } else if (connectInBackground || tryConnect()) {
126       mGracefulShutdown = false;
127       mCallbacks = callbacks;
128       mRxThread = std::thread([this]() { receiveThread(); });
129       success = true;
130     }
131   }
132 
133   return success;
134 }
135 
inReceiveThread() const136 bool SocketClient::inReceiveThread() const {
137   return (std::this_thread::get_id() == mRxThread.get_id());
138 }
139 
receiveThread()140 void SocketClient::receiveThread() {
141   LOGV("Receive thread started");
142   while (!mGracefulShutdown && (mSockFd != INVALID_SOCKET || reconnect())) {
143     struct epoll_event requestedEvent;
144     requestedEvent.data.fd = mSockFd;
145     requestedEvent.events = EPOLLIN | EPOLLWAKEUP;
146 
147     int epollFd = TEMP_FAILURE_RETRY(epoll_create1(0));
148     if (epollFd < 0) {
149       LOG_ERROR("Error creating epoll fd", errno);
150       break;
151     }
152 
153     if (TEMP_FAILURE_RETRY(epoll_ctl(epollFd, EPOLL_CTL_ADD,
154                                      requestedEvent.data.fd, &requestedEvent)) <
155         0) {
156       LOG_ERROR("Error adding socket fd to epoll", errno);
157       close(epollFd);
158       break;
159     }
160 
161     while (!mGracefulShutdown) {
162       struct epoll_event returnedEvent;
163       // Blockingly wait for the next epoll event. The implicit wakelock will be
164       // held until the next call to epoll_wait on the same epoll file
165       // descriptor
166       int eventsReady = TEMP_FAILURE_RETRY(epoll_wait(epollFd, &returnedEvent,
167                                                       /* event_count= */ 1,
168                                                       /* timeout_ms= */ -1));
169       if (eventsReady < 0) {
170         LOG_ERROR("Poll error", errno);
171         break;
172       }
173 
174       ssize_t bytesReceived =
175           recv(mSockFd, mRecvBuffer.data(), mRecvBuffer.size(), 0);
176 
177       if (bytesReceived < 0) {
178         LOG_ERROR("Exiting RX thread", errno);
179         if (!mGracefulShutdown) {
180           LOGI("Force onDisconnected");
181           mCallbacks->onDisconnected();
182         }
183         break;
184       } else if (bytesReceived == 0) {
185         if (!mGracefulShutdown) {
186           LOGI("Socket disconnected on remote end");
187           mCallbacks->onDisconnected();
188         }
189         break;
190       }
191 
192       mCallbacks->onMessageReceived(mRecvBuffer.data(), bytesReceived);
193     }
194 
195     if (close(mSockFd) != 0) {
196       LOG_ERROR("Couldn't close socket", errno);
197     }
198     mSockFd = INVALID_SOCKET;
199     close(epollFd);
200   }
201 
202   if (!mGracefulShutdown) {
203     mCallbacks->onConnectionAborted();
204   }
205 
206   mCallbacks.clear();
207   LOGV("Exiting receive thread");
208 }
209 
receiveThreadRunning() const210 bool SocketClient::receiveThreadRunning() const {
211   return mRxThread.joinable();
212 }
213 
reconnect()214 bool SocketClient::reconnect() {
215   constexpr auto kMinDelay = std::chrono::duration<int32_t, std::milli>(250);
216   constexpr auto kMaxDelay = std::chrono::minutes(5);
217   // Try reconnecting at initial delay this many times before backing off
218   constexpr unsigned int kExponentialBackoffDelay =
219       std::chrono::seconds(10) / kMinDelay;
220   // Give up after this many tries (~2.5 hours)
221   constexpr unsigned int kRetryLimit = kExponentialBackoffDelay + 40;
222   auto delay = kMinDelay;
223   unsigned int retryCount = 0;
224 
225   while (retryCount++ < kRetryLimit) {
226     {
227       std::unique_lock<std::mutex> lock(mShutdownMutex);
228       mShutdownCond.wait_for(lock, delay,
229                              [this]() { return mGracefulShutdown.load(); });
230       if (mGracefulShutdown) {
231         break;
232       }
233     }
234 
235     bool suppressErrorLogs = (delay == kMinDelay);
236     if (!tryConnect(suppressErrorLogs)) {
237       if (!suppressErrorLogs) {
238         LOGW("Failed to (re)connect, next try in %" PRId32 " ms",
239              delay.count());
240       }
241       if (retryCount > kExponentialBackoffDelay) {
242         delay *= 2;
243       }
244       if (delay > kMaxDelay) {
245         delay = kMaxDelay;
246       }
247     } else {
248       LOGD("Successfully (re)connected");
249       mCallbacks->onConnected();
250       return true;
251     }
252   }
253 
254   return false;
255 }
256 
tryConnect(bool suppressErrorLogs)257 bool SocketClient::tryConnect(bool suppressErrorLogs) {
258   bool success = false;
259 
260   errno = 0;
261   int sockFd = socket(AF_LOCAL, SOCK_SEQPACKET, 0);
262   if (sockFd >= 0) {
263     // Set the send buffer size to 2MB to allow plenty of room for nanoapp
264     // loading
265     int sndbuf = 2 * 1024 * 1024;
266     // Normally, send() should effectively return immediately, but in the event
267     // that we get blocked due to flow control, don't stay blocked for more than
268     // 3 seconds
269     struct timeval timeout = {
270         .tv_sec = 3,
271         .tv_usec = 0,
272     };
273     int ret;
274 
275     if ((ret = setsockopt(sockFd, SOL_SOCKET, SO_SNDBUF, &sndbuf,
276                           sizeof(sndbuf))) != 0) {
277       if (!suppressErrorLogs) {
278         LOGE("Failed to set SO_SNDBUF to %d: %s", sndbuf, strerror(errno));
279       }
280     } else if ((ret = setsockopt(sockFd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
281                                  sizeof(timeout))) != 0) {
282       if (!suppressErrorLogs) {
283         LOGE("Failed to set SO_SNDTIMEO: %s", strerror(errno));
284       }
285     } else {
286       mSockFd = socket_local_client_connect(sockFd, mSocketName,
287                                             ANDROID_SOCKET_NAMESPACE_RESERVED,
288                                             SOCK_SEQPACKET);
289       if (mSockFd != INVALID_SOCKET) {
290         success = true;
291       } else if (!suppressErrorLogs) {
292         LOGE("Couldn't connect client socket to '%s': %s", mSocketName,
293              strerror(errno));
294       }
295     }
296 
297     if (!success) {
298       close(sockFd);
299     }
300   } else if (!suppressErrorLogs) {
301     LOGE("Couldn't create local socket: %s", strerror(errno));
302   }
303 
304   return success;
305 }
306 
307 }  // namespace chre
308 }  // namespace android
309