1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "message_hub_manager.h"
18
19 #include <unistd.h>
20
21 #include <cstdint>
22 #include <functional>
23 #include <list>
24 #include <memory>
25 #include <optional>
26 #include <string>
27 #include <unordered_map>
28 #include <utility>
29 #include <vector>
30
31 #include <aidl/android/hardware/contexthub/BnContextHub.h>
32
33 #include "chre_host/log.h"
34 #include "pw_result/result.h"
35 #include "pw_status/status.h"
36 #include "pw_status/try.h"
37
38 namespace android::hardware::contexthub::common::implementation {
39
40 using HostHub = MessageHubManager::HostHub;
41
~HostHub()42 HostHub::~HostHub() {
43 std::lock_guard lock(mManager.mLock);
44 unlinkCallbackIfNecessaryLocked();
45 }
46
setCallback(std::shared_ptr<IEndpointCallback> callback)47 pw::Status HostHub::setCallback(std::shared_ptr<IEndpointCallback> callback) {
48 std::lock_guard lock(mManager.mLock);
49 auto *cookie = new DeathRecipientCookie{&mManager, kPid};
50 if (AIBinder_linkToDeath(callback->asBinder().get(),
51 mManager.mDeathRecipient.get(),
52 cookie) != STATUS_OK) {
53 LOGE("Failed to link callback for hub %ld (pid %d) to death recipient", kId,
54 kPid);
55 delete cookie;
56 return pw::Status::Internal();
57 }
58 unlinkCallbackIfNecessaryLocked();
59 mCallback = std::move(callback);
60 mCookie = cookie;
61 return pw::OkStatus();
62 }
63
getCallback() const64 std::shared_ptr<IEndpointCallback> HostHub::getCallback() const {
65 std::lock_guard lock(mManager.mLock);
66 return mCallback;
67 }
68
addEndpoint(std::weak_ptr<HostHub> self,const EndpointInfo & info)69 pw::Status HostHub::addEndpoint(std::weak_ptr<HostHub> self,
70 const EndpointInfo &info) {
71 std::lock_guard lock(mManager.mLock);
72 PW_TRY(checkValidLocked());
73 int64_t id = info.id.id;
74 if (auto it = mIdToEndpoint.find(id); it != mIdToEndpoint.end()) {
75 LOGE("Endpoint %ld already exists in hub %ld (pid %d)", id, kId, kPid);
76 return pw::Status::AlreadyExists();
77 }
78 if (kId == kHubIdInvalid) {
79 // If this is the hub's first endpoint, store its hub id and register it
80 // with the manager.
81 if (info.id.hubId == kContextHubServiceHubId &&
82 AIBinder_getCallingUid() != kSystemServerUid) {
83 LOGW(
84 "Non-systemserver client (pid %d) trying to register as "
85 "ContextHubService",
86 kPid);
87 return pw::Status::InvalidArgument();
88 }
89 kId = info.id.hubId;
90 mManager.mIdToHostHub.insert({kId, self});
91 }
92 mIdToEndpoint.insert({id, std::make_unique<EndpointInfo>(info)});
93 return pw::OkStatus();
94 }
95
removeEndpoint(const EndpointId & id)96 pw::Status HostHub::removeEndpoint(const EndpointId &id) {
97 std::lock_guard lock(mManager.mLock);
98 PW_TRY(checkValidLocked());
99 if (auto it = mIdToEndpoint.find(id.id); it != mIdToEndpoint.end()) {
100 mIdToEndpoint.erase(it);
101 return pw::OkStatus();
102 }
103 LOGE("Client (hub %ld, pid %d) tried to remove unknown endpoint %ld", kId,
104 kPid, id.id);
105 return pw::Status::NotFound();
106 }
107
reserveSessionIdRange(uint16_t size)108 pw::Result<std::pair<uint16_t, uint16_t>> HostHub::reserveSessionIdRange(
109 uint16_t size) {
110 std::lock_guard lock(mManager.mLock);
111 PW_TRY(checkValidLocked());
112 if (!size || size > kSessionIdMaxRange) {
113 LOGE("Client (hub %ld, pid %d) tried to allocate %hu session ids", kId,
114 kPid, size);
115 return pw::Status::InvalidArgument();
116 }
117 if (USHRT_MAX - mManager.mNextSessionId + 1 < size) {
118 LOGW("Could not allocate %hu session ids, ids exhausted", size);
119 return pw::Status::ResourceExhausted();
120 }
121 mSessionIdRanges.push_back(
122 {mManager.mNextSessionId, mManager.mNextSessionId + size - 1});
123 mManager.mNextSessionId += size;
124 return mSessionIdRanges.back();
125 }
126
openSession(std::weak_ptr<HostHub> self,const EndpointId & localId,const EndpointId & remoteId,uint16_t sessionId)127 pw::Result<std::shared_ptr<HostHub>> HostHub::openSession(
128 std::weak_ptr<HostHub> self, const EndpointId &localId,
129 const EndpointId &remoteId, uint16_t sessionId) {
130 std::lock_guard lock(mManager.mLock);
131 PW_TRY(checkValidLocked());
132
133 // Lookup the endpoints.
134 PW_TRY_ASSIGN(std::shared_ptr<EndpointInfo> local,
135 getEndpointLocked(localId));
136 PW_TRY_ASSIGN(std::shared_ptr<EndpointInfo> remote,
137 mManager.getEmbeddedEndpointLocked(remoteId));
138
139 // Validate the session id.
140 bool hostInitiated = AIBinder_isHandlingTransaction();
141 if (hostInitiated) {
142 if (!sessionIdInRangeLocked(sessionId)) {
143 LOGE("Session id %hu out of range for hub %ld", sessionId, kId);
144 return pw::Status::OutOfRange();
145 }
146 } else if (sessionId >= kHostSessionIdBase) {
147 LOGE(
148 "Remote endpoint (%ld, %ld) attempting to start session with "
149 "invalid id %hu",
150 remoteId.hubId, remoteId.id, sessionId);
151 return pw::Status::InvalidArgument();
152 }
153
154 // Prune a stale session with this id if present.
155 std::shared_ptr<HostHub> prunedHostHub;
156 if (auto it = mManager.mIdToSession.find(sessionId);
157 it != mManager.mIdToSession.end()) {
158 SessionStrongRef session(it->second);
159 if (session) {
160 // If the session is in a valid state, prune it if it was not host
161 // initiated and is pending a final ack from message router.
162 if (!hostInitiated && !session.pendingDestination &&
163 session.pendingMessageRouter) {
164 prunedHostHub = std::move(session.hub);
165 } else if (hostInitiated && session.local->id == localId) {
166 LOGE("Hub %ld trying to override its own session %hu", kId, sessionId);
167 return pw::Status::InvalidArgument();
168 } else {
169 LOGE("(host? %d) trying to override session id %hu, hub %ld",
170 hostInitiated, sessionId, kId);
171 return pw::Status::AlreadyExists();
172 }
173 }
174 mManager.mIdToSession.erase(it);
175 }
176
177 // Create and map the new session.
178 mManager.mIdToSession.emplace(
179 std::piecewise_construct, std::forward_as_tuple(sessionId),
180 std::forward_as_tuple(self, local, remote, hostInitiated));
181 return prunedHostHub;
182 }
183
closeSession(uint16_t id)184 pw::Status HostHub::closeSession(uint16_t id) {
185 std::lock_guard lock(mManager.mLock);
186 PW_TRY(checkValidLocked());
187 auto it = mManager.mIdToSession.find(id);
188 if (it == mManager.mIdToSession.end()) {
189 LOGE("Closing unopened session %hu", id);
190 return pw::Status::NotFound();
191 }
192 SessionStrongRef session(it->second);
193 if (session && session.hub->kPid != kPid) {
194 LOGE("Trying to close session %hu for client %d from client %d (hub %ld)",
195 id, session.hub->kPid, kPid, kId);
196 return pw::Status::PermissionDenied();
197 }
198 mManager.mIdToSession.erase(it);
199 return pw::OkStatus();
200 }
201
ackSession(uint16_t id)202 pw::Status HostHub::ackSession(uint16_t id) {
203 return mManager.ackSessionAndGetHostHub(id).status();
204 }
205
checkSessionOpen(uint16_t id)206 pw::Status HostHub::checkSessionOpen(uint16_t id) {
207 return mManager.checkSessionOpenAndGetHostHub(id).status();
208 }
209
id() const210 int64_t HostHub::id() const {
211 std::lock_guard lock(mManager.mLock);
212 return kId;
213 }
214
unlinkFromManager()215 int64_t MessageHubManager::HostHub::unlinkFromManager() {
216 std::lock_guard lock(mManager.mLock);
217 // TODO(b/378545373): Release the session id range.
218 if (kId != kHubIdInvalid) mManager.mIdToHostHub.erase(kId);
219 mManager.mPidToHostHub.erase(kPid);
220 mUnlinked = true;
221 return kId;
222 }
223
unlinkCallbackIfNecessaryLocked()224 void HostHub::unlinkCallbackIfNecessaryLocked() {
225 if (!mCallback) return;
226 if (AIBinder_unlinkToDeath(mCallback->asBinder().get(),
227 mManager.mDeathRecipient.get(),
228 mCookie) != STATUS_OK) {
229 LOGW("Failed to unlink client (pid: %d, hub id: %ld)", kPid, kId);
230 }
231 mCallback.reset();
232 mCookie = nullptr;
233 }
234
checkValidLocked()235 pw::Status HostHub::checkValidLocked() {
236 if (!mCallback) {
237 ALOGW("Endpoint APIs invoked by client %d before callback registered",
238 kPid);
239 return pw::Status::FailedPrecondition();
240 } else if (mUnlinked) {
241 ALOGW("Client %d went down mid-operation", kPid);
242 return pw::Status::Aborted();
243 }
244 return pw::OkStatus();
245 }
246
getEndpointLocked(const EndpointId & id)247 pw::Result<std::shared_ptr<EndpointInfo>> HostHub::getEndpointLocked(
248 const EndpointId &id) {
249 if (id.hubId != kId) {
250 LOGE("Rejecting lookup on unowned endpoint (%ld, %ld) from hub %ld",
251 id.hubId, id.id, kId);
252 return pw::Status::InvalidArgument();
253 }
254 if (auto it = mIdToEndpoint.find(id.id); it != mIdToEndpoint.end())
255 return it->second;
256 return pw::Status::NotFound();
257 }
258
sessionIdInRangeLocked(uint16_t id)259 bool HostHub::sessionIdInRangeLocked(uint16_t id) {
260 for (auto range : mSessionIdRanges) {
261 if (id >= range.first && id <= range.second) return true;
262 }
263 return false;
264 }
265
MessageHubManager(HostHubDownCb cb)266 MessageHubManager::MessageHubManager(HostHubDownCb cb)
267 : mHostHubDownCb(std::move(cb)) {
268 mDeathRecipient = ndk::ScopedAIBinder_DeathRecipient(
269 AIBinder_DeathRecipient_new(onClientDeath));
270 AIBinder_DeathRecipient_setOnUnlinked(
271 mDeathRecipient.get(), /*onUnlinked= */ [](void *cookie) {
272 LOGI("Callback is unlinked. Releasing the death recipient cookie.");
273 delete static_cast<HostHub::DeathRecipientCookie *>(cookie);
274 });
275 }
276
getHostHubByPid(pid_t pid)277 std::shared_ptr<HostHub> MessageHubManager::getHostHubByPid(pid_t pid) {
278 std::lock_guard lock(mLock);
279 if (auto it = mPidToHostHub.find(pid); it != mPidToHostHub.end())
280 return it->second;
281 std::shared_ptr<HostHub> hub(new HostHub(*this, pid));
282 mPidToHostHub.insert({pid, hub});
283 return hub;
284 }
285
getHostHubByEndpointId(const EndpointId & id)286 std::shared_ptr<HostHub> MessageHubManager::getHostHubByEndpointId(
287 const EndpointId &id) {
288 std::lock_guard lock(mLock);
289 if (auto it = mIdToHostHub.find(id.hubId); it != mIdToHostHub.end())
290 return it->second.lock();
291 return {};
292 }
293
294 pw::Result<std::shared_ptr<HostHub>>
checkSessionOpenAndGetHostHub(uint16_t id)295 MessageHubManager::checkSessionOpenAndGetHostHub(uint16_t id) {
296 std::lock_guard lock(mLock);
297 PW_TRY_ASSIGN(SessionStrongRef session, checkSessionLocked(id));
298 if (AIBinder_getCallingPid() != session.hub->kPid) {
299 LOGE("Trying to check unowned session %hu", id);
300 return pw::Status::PermissionDenied();
301 }
302 if (!session.pendingDestination && !session.pendingMessageRouter)
303 return std::move(session.hub);
304 LOGE("Session %hu is pending", id);
305 return pw::Status::FailedPrecondition();
306 }
307
ackSessionAndGetHostHub(uint16_t id)308 pw::Result<std::shared_ptr<HostHub>> MessageHubManager::ackSessionAndGetHostHub(
309 uint16_t id) {
310 std::lock_guard lock(mLock);
311 PW_TRY_ASSIGN(SessionStrongRef session, checkSessionLocked(id));
312 bool isBinderCall = AIBinder_isHandlingTransaction();
313 bool isHostSession = id >= kHostSessionIdBase;
314 if (isBinderCall && AIBinder_getCallingPid() != session.hub->kPid) {
315 LOGE("Trying to ack unowned session %hu", id);
316 return pw::Status::PermissionDenied();
317 } else if (session.pendingDestination) {
318 if (isHostSession == isBinderCall) {
319 LOGE("Session %hu must be acked by other side (host? %d)", id,
320 !isBinderCall);
321 return pw::Status::PermissionDenied();
322 }
323 session.pendingDestination = false;
324 } else if (session.pendingMessageRouter) {
325 if (isBinderCall) {
326 LOGE("Message router must ack session %hu", id);
327 return pw::Status::PermissionDenied();
328 }
329 session.pendingMessageRouter = false;
330 } else {
331 LOGE("Received unexpected ack on session %hu, host: %d", id, isBinderCall);
332 }
333 return std::move(session.hub);
334 }
335
forEachHostHub(std::function<void (HostHub & hub)> fn)336 void MessageHubManager::forEachHostHub(std::function<void(HostHub &hub)> fn) {
337 std::list<std::shared_ptr<HostHub>> hubs;
338 {
339 std::lock_guard lock(mLock);
340 for (auto &[pid, hub] : mPidToHostHub) hubs.push_back(hub);
341 }
342 for (auto &hub : hubs) fn(*hub);
343 }
344
345 pw::Result<MessageHubManager::SessionStrongRef>
checkSessionLocked(uint16_t id)346 MessageHubManager::checkSessionLocked(uint16_t id) {
347 auto sessionIt = mIdToSession.find(id);
348 if (sessionIt == mIdToSession.end()) {
349 LOGE("Did not find expected session %hu", id);
350 return pw::Status::NotFound();
351 }
352 SessionStrongRef session(sessionIt->second);
353 if (!session) {
354 LOGD(
355 "Pruning session %hu due to one or more of host hub, host endpoint, "
356 "or embedded endpoint going down.",
357 id);
358 mIdToSession.erase(sessionIt);
359 return pw::Status::Unavailable();
360 }
361 return std::move(session);
362 }
363
initEmbeddedHubsAndEndpoints(const std::vector<HubInfo> & hubs,const std::vector<EndpointInfo> & endpoints)364 void MessageHubManager::initEmbeddedHubsAndEndpoints(
365 const std::vector<HubInfo> &hubs,
366 const std::vector<EndpointInfo> &endpoints) {
367 std::lock_guard lock(mLock);
368 mIdToEmbeddedHub.clear();
369 for (const auto &hub : hubs) mIdToEmbeddedHub[hub.hubId].info = hub;
370 for (const auto &endpoint : endpoints) addEmbeddedEndpointLocked(endpoint);
371 }
372
addEmbeddedHub(const HubInfo & hub)373 void MessageHubManager::addEmbeddedHub(const HubInfo &hub) {
374 std::lock_guard lock(mLock);
375 if (mIdToEmbeddedHub.count(hub.hubId)) return;
376 mIdToEmbeddedHub[hub.hubId].info = hub;
377 }
378
removeEmbeddedHub(int64_t id)379 std::vector<EndpointId> MessageHubManager::removeEmbeddedHub(int64_t id) {
380 std::lock_guard lock(mLock);
381 std::vector<EndpointId> endpoints;
382 auto it = mIdToEmbeddedHub.find(id);
383 if (it != mIdToEmbeddedHub.end()) {
384 for (const auto &[endpointId, info] : it->second.idToEndpoint)
385 endpoints.push_back({.id = endpointId, .hubId = id});
386 mIdToEmbeddedHub.erase(it);
387 }
388 return endpoints;
389 }
390
getEmbeddedHubs() const391 std::vector<HubInfo> MessageHubManager::getEmbeddedHubs() const {
392 std::lock_guard lock(mLock);
393 std::vector<HubInfo> hubs;
394 for (const auto &[id, hub] : mIdToEmbeddedHub) hubs.push_back(hub.info);
395 return hubs;
396 }
397
addEmbeddedEndpoint(const EndpointInfo & endpoint)398 void MessageHubManager::addEmbeddedEndpoint(const EndpointInfo &endpoint) {
399 std::lock_guard lock(mLock);
400 addEmbeddedEndpointLocked(endpoint);
401 }
402
getEmbeddedEndpoints() const403 std::vector<EndpointInfo> MessageHubManager::getEmbeddedEndpoints() const {
404 std::lock_guard lock(mLock);
405 std::vector<EndpointInfo> endpoints;
406 for (const auto &[id, hub] : mIdToEmbeddedHub) {
407 for (const auto &[endptId, endptInfo] : hub.idToEndpoint)
408 endpoints.push_back(*endptInfo);
409 }
410 return endpoints;
411 }
412
onClientDeath(void * cookie)413 void MessageHubManager::onClientDeath(void *cookie) {
414 auto *cookieData = reinterpret_cast<HostHub::DeathRecipientCookie *>(cookie);
415 MessageHubManager *manager = cookieData->manager;
416 std::shared_ptr<HostHub> hub = manager->getHostHubByPid(cookieData->pid);
417 LOGW("Hub %ld (pid %d) died", hub->id(), cookieData->pid);
418 manager->mHostHubDownCb(hub->unlinkFromManager());
419 }
420
addEmbeddedEndpointLocked(const EndpointInfo & endpoint)421 void MessageHubManager::addEmbeddedEndpointLocked(
422 const EndpointInfo &endpoint) {
423 auto it = mIdToEmbeddedHub.find(endpoint.id.hubId);
424 if (it == mIdToEmbeddedHub.end()) {
425 LOGW("Could not find hub %ld for endpoint %ld", endpoint.id.hubId,
426 endpoint.id.id);
427 return;
428 }
429 it->second.idToEndpoint.insert(
430 {endpoint.id.id, std::make_shared<EndpointInfo>(endpoint)});
431 }
432
433 pw::Result<std::shared_ptr<EndpointInfo>>
getEmbeddedEndpointLocked(const EndpointId & id)434 MessageHubManager::getEmbeddedEndpointLocked(const EndpointId &id) {
435 auto hubIt = mIdToEmbeddedHub.find(id.hubId);
436 if (hubIt != mIdToEmbeddedHub.end()) {
437 auto it = hubIt->second.idToEndpoint.find(id.id);
438 if (it != hubIt->second.idToEndpoint.end()) return it->second;
439 }
440 LOGW("Could not find remote endpoint (%ld, %ld)", id.hubId, id.id);
441 return pw::Status::NotFound();
442 }
443
444 } // namespace android::hardware::contexthub::common::implementation
445