xref: /aosp_15_r20/system/chre/host/hal_generic/common/message_hub_manager.cc (revision 84e339476a462649f82315436d70fd732297a399)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "message_hub_manager.h"
18 
19 #include <unistd.h>
20 
21 #include <cstdint>
22 #include <functional>
23 #include <list>
24 #include <memory>
25 #include <optional>
26 #include <string>
27 #include <unordered_map>
28 #include <utility>
29 #include <vector>
30 
31 #include <aidl/android/hardware/contexthub/BnContextHub.h>
32 
33 #include "chre_host/log.h"
34 #include "pw_result/result.h"
35 #include "pw_status/status.h"
36 #include "pw_status/try.h"
37 
38 namespace android::hardware::contexthub::common::implementation {
39 
40 using HostHub = MessageHubManager::HostHub;
41 
~HostHub()42 HostHub::~HostHub() {
43   std::lock_guard lock(mManager.mLock);
44   unlinkCallbackIfNecessaryLocked();
45 }
46 
setCallback(std::shared_ptr<IEndpointCallback> callback)47 pw::Status HostHub::setCallback(std::shared_ptr<IEndpointCallback> callback) {
48   std::lock_guard lock(mManager.mLock);
49   auto *cookie = new DeathRecipientCookie{&mManager, kPid};
50   if (AIBinder_linkToDeath(callback->asBinder().get(),
51                            mManager.mDeathRecipient.get(),
52                            cookie) != STATUS_OK) {
53     LOGE("Failed to link callback for hub %ld (pid %d) to death recipient", kId,
54          kPid);
55     delete cookie;
56     return pw::Status::Internal();
57   }
58   unlinkCallbackIfNecessaryLocked();
59   mCallback = std::move(callback);
60   mCookie = cookie;
61   return pw::OkStatus();
62 }
63 
getCallback() const64 std::shared_ptr<IEndpointCallback> HostHub::getCallback() const {
65   std::lock_guard lock(mManager.mLock);
66   return mCallback;
67 }
68 
addEndpoint(std::weak_ptr<HostHub> self,const EndpointInfo & info)69 pw::Status HostHub::addEndpoint(std::weak_ptr<HostHub> self,
70                                 const EndpointInfo &info) {
71   std::lock_guard lock(mManager.mLock);
72   PW_TRY(checkValidLocked());
73   int64_t id = info.id.id;
74   if (auto it = mIdToEndpoint.find(id); it != mIdToEndpoint.end()) {
75     LOGE("Endpoint %ld already exists in hub %ld (pid %d)", id, kId, kPid);
76     return pw::Status::AlreadyExists();
77   }
78   if (kId == kHubIdInvalid) {
79     // If this is the hub's first endpoint, store its hub id and register it
80     // with the manager.
81     if (info.id.hubId == kContextHubServiceHubId &&
82         AIBinder_getCallingUid() != kSystemServerUid) {
83       LOGW(
84           "Non-systemserver client (pid %d) trying to register as "
85           "ContextHubService",
86           kPid);
87       return pw::Status::InvalidArgument();
88     }
89     kId = info.id.hubId;
90     mManager.mIdToHostHub.insert({kId, self});
91   }
92   mIdToEndpoint.insert({id, std::make_unique<EndpointInfo>(info)});
93   return pw::OkStatus();
94 }
95 
removeEndpoint(const EndpointId & id)96 pw::Status HostHub::removeEndpoint(const EndpointId &id) {
97   std::lock_guard lock(mManager.mLock);
98   PW_TRY(checkValidLocked());
99   if (auto it = mIdToEndpoint.find(id.id); it != mIdToEndpoint.end()) {
100     mIdToEndpoint.erase(it);
101     return pw::OkStatus();
102   }
103   LOGE("Client (hub %ld, pid %d) tried to remove unknown endpoint %ld", kId,
104        kPid, id.id);
105   return pw::Status::NotFound();
106 }
107 
reserveSessionIdRange(uint16_t size)108 pw::Result<std::pair<uint16_t, uint16_t>> HostHub::reserveSessionIdRange(
109     uint16_t size) {
110   std::lock_guard lock(mManager.mLock);
111   PW_TRY(checkValidLocked());
112   if (!size || size > kSessionIdMaxRange) {
113     LOGE("Client (hub %ld, pid %d) tried to allocate %hu session ids", kId,
114          kPid, size);
115     return pw::Status::InvalidArgument();
116   }
117   if (USHRT_MAX - mManager.mNextSessionId + 1 < size) {
118     LOGW("Could not allocate %hu session ids, ids exhausted", size);
119     return pw::Status::ResourceExhausted();
120   }
121   mSessionIdRanges.push_back(
122       {mManager.mNextSessionId, mManager.mNextSessionId + size - 1});
123   mManager.mNextSessionId += size;
124   return mSessionIdRanges.back();
125 }
126 
openSession(std::weak_ptr<HostHub> self,const EndpointId & localId,const EndpointId & remoteId,uint16_t sessionId)127 pw::Result<std::shared_ptr<HostHub>> HostHub::openSession(
128     std::weak_ptr<HostHub> self, const EndpointId &localId,
129     const EndpointId &remoteId, uint16_t sessionId) {
130   std::lock_guard lock(mManager.mLock);
131   PW_TRY(checkValidLocked());
132 
133   // Lookup the endpoints.
134   PW_TRY_ASSIGN(std::shared_ptr<EndpointInfo> local,
135                 getEndpointLocked(localId));
136   PW_TRY_ASSIGN(std::shared_ptr<EndpointInfo> remote,
137                 mManager.getEmbeddedEndpointLocked(remoteId));
138 
139   // Validate the session id.
140   bool hostInitiated = AIBinder_isHandlingTransaction();
141   if (hostInitiated) {
142     if (!sessionIdInRangeLocked(sessionId)) {
143       LOGE("Session id %hu out of range for hub %ld", sessionId, kId);
144       return pw::Status::OutOfRange();
145     }
146   } else if (sessionId >= kHostSessionIdBase) {
147     LOGE(
148         "Remote endpoint (%ld, %ld) attempting to start session with "
149         "invalid id %hu",
150         remoteId.hubId, remoteId.id, sessionId);
151     return pw::Status::InvalidArgument();
152   }
153 
154   // Prune a stale session with this id if present.
155   std::shared_ptr<HostHub> prunedHostHub;
156   if (auto it = mManager.mIdToSession.find(sessionId);
157       it != mManager.mIdToSession.end()) {
158     SessionStrongRef session(it->second);
159     if (session) {
160       // If the session is in a valid state, prune it if it was not host
161       // initiated and is pending a final ack from message router.
162       if (!hostInitiated && !session.pendingDestination &&
163           session.pendingMessageRouter) {
164         prunedHostHub = std::move(session.hub);
165       } else if (hostInitiated && session.local->id == localId) {
166         LOGE("Hub %ld trying to override its own session %hu", kId, sessionId);
167         return pw::Status::InvalidArgument();
168       } else {
169         LOGE("(host? %d) trying to override session id %hu, hub %ld",
170              hostInitiated, sessionId, kId);
171         return pw::Status::AlreadyExists();
172       }
173     }
174     mManager.mIdToSession.erase(it);
175   }
176 
177   // Create and map the new session.
178   mManager.mIdToSession.emplace(
179       std::piecewise_construct, std::forward_as_tuple(sessionId),
180       std::forward_as_tuple(self, local, remote, hostInitiated));
181   return prunedHostHub;
182 }
183 
closeSession(uint16_t id)184 pw::Status HostHub::closeSession(uint16_t id) {
185   std::lock_guard lock(mManager.mLock);
186   PW_TRY(checkValidLocked());
187   auto it = mManager.mIdToSession.find(id);
188   if (it == mManager.mIdToSession.end()) {
189     LOGE("Closing unopened session %hu", id);
190     return pw::Status::NotFound();
191   }
192   SessionStrongRef session(it->second);
193   if (session && session.hub->kPid != kPid) {
194     LOGE("Trying to close session %hu for client %d from client %d (hub %ld)",
195          id, session.hub->kPid, kPid, kId);
196     return pw::Status::PermissionDenied();
197   }
198   mManager.mIdToSession.erase(it);
199   return pw::OkStatus();
200 }
201 
ackSession(uint16_t id)202 pw::Status HostHub::ackSession(uint16_t id) {
203   return mManager.ackSessionAndGetHostHub(id).status();
204 }
205 
checkSessionOpen(uint16_t id)206 pw::Status HostHub::checkSessionOpen(uint16_t id) {
207   return mManager.checkSessionOpenAndGetHostHub(id).status();
208 }
209 
id() const210 int64_t HostHub::id() const {
211   std::lock_guard lock(mManager.mLock);
212   return kId;
213 }
214 
unlinkFromManager()215 int64_t MessageHubManager::HostHub::unlinkFromManager() {
216   std::lock_guard lock(mManager.mLock);
217   // TODO(b/378545373): Release the session id range.
218   if (kId != kHubIdInvalid) mManager.mIdToHostHub.erase(kId);
219   mManager.mPidToHostHub.erase(kPid);
220   mUnlinked = true;
221   return kId;
222 }
223 
unlinkCallbackIfNecessaryLocked()224 void HostHub::unlinkCallbackIfNecessaryLocked() {
225   if (!mCallback) return;
226   if (AIBinder_unlinkToDeath(mCallback->asBinder().get(),
227                              mManager.mDeathRecipient.get(),
228                              mCookie) != STATUS_OK) {
229     LOGW("Failed to unlink client (pid: %d, hub id: %ld)", kPid, kId);
230   }
231   mCallback.reset();
232   mCookie = nullptr;
233 }
234 
checkValidLocked()235 pw::Status HostHub::checkValidLocked() {
236   if (!mCallback) {
237     ALOGW("Endpoint APIs invoked by client %d before callback registered",
238           kPid);
239     return pw::Status::FailedPrecondition();
240   } else if (mUnlinked) {
241     ALOGW("Client %d went down mid-operation", kPid);
242     return pw::Status::Aborted();
243   }
244   return pw::OkStatus();
245 }
246 
getEndpointLocked(const EndpointId & id)247 pw::Result<std::shared_ptr<EndpointInfo>> HostHub::getEndpointLocked(
248     const EndpointId &id) {
249   if (id.hubId != kId) {
250     LOGE("Rejecting lookup on unowned endpoint (%ld, %ld) from hub %ld",
251          id.hubId, id.id, kId);
252     return pw::Status::InvalidArgument();
253   }
254   if (auto it = mIdToEndpoint.find(id.id); it != mIdToEndpoint.end())
255     return it->second;
256   return pw::Status::NotFound();
257 }
258 
sessionIdInRangeLocked(uint16_t id)259 bool HostHub::sessionIdInRangeLocked(uint16_t id) {
260   for (auto range : mSessionIdRanges) {
261     if (id >= range.first && id <= range.second) return true;
262   }
263   return false;
264 }
265 
MessageHubManager(HostHubDownCb cb)266 MessageHubManager::MessageHubManager(HostHubDownCb cb)
267     : mHostHubDownCb(std::move(cb)) {
268   mDeathRecipient = ndk::ScopedAIBinder_DeathRecipient(
269       AIBinder_DeathRecipient_new(onClientDeath));
270   AIBinder_DeathRecipient_setOnUnlinked(
271       mDeathRecipient.get(), /*onUnlinked= */ [](void *cookie) {
272         LOGI("Callback is unlinked. Releasing the death recipient cookie.");
273         delete static_cast<HostHub::DeathRecipientCookie *>(cookie);
274       });
275 }
276 
getHostHubByPid(pid_t pid)277 std::shared_ptr<HostHub> MessageHubManager::getHostHubByPid(pid_t pid) {
278   std::lock_guard lock(mLock);
279   if (auto it = mPidToHostHub.find(pid); it != mPidToHostHub.end())
280     return it->second;
281   std::shared_ptr<HostHub> hub(new HostHub(*this, pid));
282   mPidToHostHub.insert({pid, hub});
283   return hub;
284 }
285 
getHostHubByEndpointId(const EndpointId & id)286 std::shared_ptr<HostHub> MessageHubManager::getHostHubByEndpointId(
287     const EndpointId &id) {
288   std::lock_guard lock(mLock);
289   if (auto it = mIdToHostHub.find(id.hubId); it != mIdToHostHub.end())
290     return it->second.lock();
291   return {};
292 }
293 
294 pw::Result<std::shared_ptr<HostHub>>
checkSessionOpenAndGetHostHub(uint16_t id)295 MessageHubManager::checkSessionOpenAndGetHostHub(uint16_t id) {
296   std::lock_guard lock(mLock);
297   PW_TRY_ASSIGN(SessionStrongRef session, checkSessionLocked(id));
298   if (AIBinder_getCallingPid() != session.hub->kPid) {
299     LOGE("Trying to check unowned session %hu", id);
300     return pw::Status::PermissionDenied();
301   }
302   if (!session.pendingDestination && !session.pendingMessageRouter)
303     return std::move(session.hub);
304   LOGE("Session %hu is pending", id);
305   return pw::Status::FailedPrecondition();
306 }
307 
ackSessionAndGetHostHub(uint16_t id)308 pw::Result<std::shared_ptr<HostHub>> MessageHubManager::ackSessionAndGetHostHub(
309     uint16_t id) {
310   std::lock_guard lock(mLock);
311   PW_TRY_ASSIGN(SessionStrongRef session, checkSessionLocked(id));
312   bool isBinderCall = AIBinder_isHandlingTransaction();
313   bool isHostSession = id >= kHostSessionIdBase;
314   if (isBinderCall && AIBinder_getCallingPid() != session.hub->kPid) {
315     LOGE("Trying to ack unowned session %hu", id);
316     return pw::Status::PermissionDenied();
317   } else if (session.pendingDestination) {
318     if (isHostSession == isBinderCall) {
319       LOGE("Session %hu must be acked by other side (host? %d)", id,
320            !isBinderCall);
321       return pw::Status::PermissionDenied();
322     }
323     session.pendingDestination = false;
324   } else if (session.pendingMessageRouter) {
325     if (isBinderCall) {
326       LOGE("Message router must ack session %hu", id);
327       return pw::Status::PermissionDenied();
328     }
329     session.pendingMessageRouter = false;
330   } else {
331     LOGE("Received unexpected ack on session %hu, host: %d", id, isBinderCall);
332   }
333   return std::move(session.hub);
334 }
335 
forEachHostHub(std::function<void (HostHub & hub)> fn)336 void MessageHubManager::forEachHostHub(std::function<void(HostHub &hub)> fn) {
337   std::list<std::shared_ptr<HostHub>> hubs;
338   {
339     std::lock_guard lock(mLock);
340     for (auto &[pid, hub] : mPidToHostHub) hubs.push_back(hub);
341   }
342   for (auto &hub : hubs) fn(*hub);
343 }
344 
345 pw::Result<MessageHubManager::SessionStrongRef>
checkSessionLocked(uint16_t id)346 MessageHubManager::checkSessionLocked(uint16_t id) {
347   auto sessionIt = mIdToSession.find(id);
348   if (sessionIt == mIdToSession.end()) {
349     LOGE("Did not find expected session %hu", id);
350     return pw::Status::NotFound();
351   }
352   SessionStrongRef session(sessionIt->second);
353   if (!session) {
354     LOGD(
355         "Pruning session %hu due to one or more of host hub, host endpoint, "
356         "or embedded endpoint going down.",
357         id);
358     mIdToSession.erase(sessionIt);
359     return pw::Status::Unavailable();
360   }
361   return std::move(session);
362 }
363 
initEmbeddedHubsAndEndpoints(const std::vector<HubInfo> & hubs,const std::vector<EndpointInfo> & endpoints)364 void MessageHubManager::initEmbeddedHubsAndEndpoints(
365     const std::vector<HubInfo> &hubs,
366     const std::vector<EndpointInfo> &endpoints) {
367   std::lock_guard lock(mLock);
368   mIdToEmbeddedHub.clear();
369   for (const auto &hub : hubs) mIdToEmbeddedHub[hub.hubId].info = hub;
370   for (const auto &endpoint : endpoints) addEmbeddedEndpointLocked(endpoint);
371 }
372 
addEmbeddedHub(const HubInfo & hub)373 void MessageHubManager::addEmbeddedHub(const HubInfo &hub) {
374   std::lock_guard lock(mLock);
375   if (mIdToEmbeddedHub.count(hub.hubId)) return;
376   mIdToEmbeddedHub[hub.hubId].info = hub;
377 }
378 
removeEmbeddedHub(int64_t id)379 std::vector<EndpointId> MessageHubManager::removeEmbeddedHub(int64_t id) {
380   std::lock_guard lock(mLock);
381   std::vector<EndpointId> endpoints;
382   auto it = mIdToEmbeddedHub.find(id);
383   if (it != mIdToEmbeddedHub.end()) {
384     for (const auto &[endpointId, info] : it->second.idToEndpoint)
385       endpoints.push_back({.id = endpointId, .hubId = id});
386     mIdToEmbeddedHub.erase(it);
387   }
388   return endpoints;
389 }
390 
getEmbeddedHubs() const391 std::vector<HubInfo> MessageHubManager::getEmbeddedHubs() const {
392   std::lock_guard lock(mLock);
393   std::vector<HubInfo> hubs;
394   for (const auto &[id, hub] : mIdToEmbeddedHub) hubs.push_back(hub.info);
395   return hubs;
396 }
397 
addEmbeddedEndpoint(const EndpointInfo & endpoint)398 void MessageHubManager::addEmbeddedEndpoint(const EndpointInfo &endpoint) {
399   std::lock_guard lock(mLock);
400   addEmbeddedEndpointLocked(endpoint);
401 }
402 
getEmbeddedEndpoints() const403 std::vector<EndpointInfo> MessageHubManager::getEmbeddedEndpoints() const {
404   std::lock_guard lock(mLock);
405   std::vector<EndpointInfo> endpoints;
406   for (const auto &[id, hub] : mIdToEmbeddedHub) {
407     for (const auto &[endptId, endptInfo] : hub.idToEndpoint)
408       endpoints.push_back(*endptInfo);
409   }
410   return endpoints;
411 }
412 
onClientDeath(void * cookie)413 void MessageHubManager::onClientDeath(void *cookie) {
414   auto *cookieData = reinterpret_cast<HostHub::DeathRecipientCookie *>(cookie);
415   MessageHubManager *manager = cookieData->manager;
416   std::shared_ptr<HostHub> hub = manager->getHostHubByPid(cookieData->pid);
417   LOGW("Hub %ld (pid %d) died", hub->id(), cookieData->pid);
418   manager->mHostHubDownCb(hub->unlinkFromManager());
419 }
420 
addEmbeddedEndpointLocked(const EndpointInfo & endpoint)421 void MessageHubManager::addEmbeddedEndpointLocked(
422     const EndpointInfo &endpoint) {
423   auto it = mIdToEmbeddedHub.find(endpoint.id.hubId);
424   if (it == mIdToEmbeddedHub.end()) {
425     LOGW("Could not find hub %ld for endpoint %ld", endpoint.id.hubId,
426          endpoint.id.id);
427     return;
428   }
429   it->second.idToEndpoint.insert(
430       {endpoint.id.id, std::make_shared<EndpointInfo>(endpoint)});
431 }
432 
433 pw::Result<std::shared_ptr<EndpointInfo>>
getEmbeddedEndpointLocked(const EndpointId & id)434 MessageHubManager::getEmbeddedEndpointLocked(const EndpointId &id) {
435   auto hubIt = mIdToEmbeddedHub.find(id.hubId);
436   if (hubIt != mIdToEmbeddedHub.end()) {
437     auto it = hubIt->second.idToEndpoint.find(id.id);
438     if (it != hubIt->second.idToEndpoint.end()) return it->second;
439   }
440   LOGW("Could not find remote endpoint (%ld, %ld)", id.hubId, id.id);
441   return pw::Status::NotFound();
442 }
443 
444 }  // namespace android::hardware::contexthub::common::implementation
445